Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b1dbdf22
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7cb1fdd4d109165722858cd626abbfc8d5e2ebc4"
Unverified
Commit
b1dbdf22
authored
Nov 11, 2021
by
Suraj Patil
Committed by
GitHub
Nov 11, 2021
Browse files
pass params to encode (#14370)
parent
e92190c0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+3
-3
No files found.
src/transformers/generation_flax_utils.py
View file @
b1dbdf22
...
@@ -132,13 +132,13 @@ class FlaxGenerationMixin:
...
@@ -132,13 +132,13 @@ class FlaxGenerationMixin:
state
=
body_fn
(
state
)
state
=
body_fn
(
state
)
return
state
return
state
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
input_ids
,
model_kwargs
):
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
input_ids
,
params
,
model_kwargs
):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
argument
:
value
for
argument
,
value
in
model_kwargs
.
items
()
for
argument
,
value
in
model_kwargs
.
items
()
if
not
(
argument
.
startswith
(
"decoder_"
)
or
argument
.
startswith
(
"cross_attn"
))
if
not
(
argument
.
startswith
(
"decoder_"
)
or
argument
.
startswith
(
"cross_attn"
))
}
}
model_kwargs
[
"encoder_outputs"
]
=
self
.
encode
(
input_ids
,
return_dict
=
True
,
**
encoder_kwargs
)
model_kwargs
[
"encoder_outputs"
]
=
self
.
encode
(
input_ids
,
params
=
params
,
return_dict
=
True
,
**
encoder_kwargs
)
return
model_kwargs
return
model_kwargs
@
staticmethod
@
staticmethod
...
@@ -251,7 +251,7 @@ class FlaxGenerationMixin:
...
@@ -251,7 +251,7 @@ class FlaxGenerationMixin:
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
# add encoder_outputs to model_kwargs
# add encoder_outputs to model_kwargs
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
model_kwargs
)
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
params
,
model_kwargs
)
# prepare decoder_input_ids for generation
# prepare decoder_input_ids for generation
input_ids
=
jnp
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
"i4"
)
*
decoder_start_token_id
input_ids
=
jnp
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
"i4"
)
*
decoder_start_token_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment