Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b1dbdf22
Unverified
Commit
b1dbdf22
authored
Nov 11, 2021
by
Suraj Patil
Committed by
GitHub
Nov 11, 2021
Browse files
pass params to encode (#14370)
parent
e92190c0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+3
-3
No files found.
src/transformers/generation_flax_utils.py
View file @
b1dbdf22
...
@@ -132,13 +132,13 @@ class FlaxGenerationMixin:
...
@@ -132,13 +132,13 @@ class FlaxGenerationMixin:
state
=
body_fn
(
state
)
state
=
body_fn
(
state
)
return
state
return
state
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
input_ids
,
model_kwargs
):
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
input_ids
,
params
,
model_kwargs
):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
argument
:
value
for
argument
,
value
in
model_kwargs
.
items
()
for
argument
,
value
in
model_kwargs
.
items
()
if
not
(
argument
.
startswith
(
"decoder_"
)
or
argument
.
startswith
(
"cross_attn"
))
if
not
(
argument
.
startswith
(
"decoder_"
)
or
argument
.
startswith
(
"cross_attn"
))
}
}
model_kwargs
[
"encoder_outputs"
]
=
self
.
encode
(
input_ids
,
return_dict
=
True
,
**
encoder_kwargs
)
model_kwargs
[
"encoder_outputs"
]
=
self
.
encode
(
input_ids
,
params
=
params
,
return_dict
=
True
,
**
encoder_kwargs
)
return
model_kwargs
return
model_kwargs
@
staticmethod
@
staticmethod
...
@@ -251,7 +251,7 @@ class FlaxGenerationMixin:
...
@@ -251,7 +251,7 @@ class FlaxGenerationMixin:
if
self
.
config
.
is_encoder_decoder
:
if
self
.
config
.
is_encoder_decoder
:
# add encoder_outputs to model_kwargs
# add encoder_outputs to model_kwargs
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
model_kwargs
)
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
params
,
model_kwargs
)
# prepare decoder_input_ids for generation
# prepare decoder_input_ids for generation
input_ids
=
jnp
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
"i4"
)
*
decoder_start_token_id
input_ids
=
jnp
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
"i4"
)
*
decoder_start_token_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment