Unverified commit beb24f2a, authored Jan 05, 2023 by Joao Gante, committed by GitHub on Jan 05, 2023

Generate: FLAX infers pad token in its absence and has functional example (#21009)

parent 480799f7
Showing 3 changed files with 26 additions and 14 deletions (+26 -14):

src/transformers/generation/flax_utils.py   +20 -9
src/transformers/generation/tf_utils.py     +5  -5
utils/documentation_tests.txt               +1  -0
src/transformers/generation/flax_utils.py

@@ -305,10 +305,10 @@ class FlaxGenerationMixin:
         >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
         >>> input_context = "The dog"
         >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
+        >>> inputs = tokenizer(input_context, return_tensors="np")
         >>> # generate candidates using sampling
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
-        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
         ```"""
         # Validate the `.generate()` call
         self._validate_model_class()
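
For reference, the corrected docstring example also works as a standalone script. This is a minimal sketch assembled from the hunk above; the tokenizer line is assumed from the surrounding docstring context rather than from this diff:

    from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # assumed context, not part of this hunk
    model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")

    # Passing the whole BatchEncoding via **inputs also forwards the
    # attention_mask, which the warning added below encourages.
    inputs = tokenizer("The dog", return_tensors="np")

    # Flax generate returns an output object, so the token ids live in `.sequences`.
    outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
    print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))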
@@ -323,6 +323,17 @@ class FlaxGenerationMixin:
         )
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

+        if pad_token_id is None and eos_token_id is not None:
+            if model_kwargs.get("attention_mask") is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            pad_token_id = eos_token_id
+
         if decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
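
The added fallback is self-contained enough to restate outside the mixin. A minimal sketch with a hypothetical helper name, assuming only the standard library:

    import logging

    logger = logging.getLogger(__name__)

    def infer_pad_token_id(pad_token_id, eos_token_id, attention_mask=None):
        # Hypothetical standalone version of the fallback added above: when no
        # pad token is configured but an EOS token is, pad with EOS, taking the
        # first entry if several EOS ids are configured.
        if pad_token_id is None and eos_token_id is not None:
            if attention_mask is None:
                logger.warning(
                    "The attention mask and the pad token id were not set; you may observe unexpected behavior."
                )
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id
        return pad_token_id

    assert infer_pad_token_id(None, 50256) == 50256
    assert infer_pad_token_id(None, [50256, 0]) == 50256
    assert infer_pad_token_id(0, 50256) == 0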
@@ -525,8 +536,8 @@ class FlaxGenerationMixin:
         batch_size, cur_len = input_ids.shape
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)

         # per batch-item holding current token in loop.
@@ -614,8 +625,8 @@ class FlaxGenerationMixin:
         batch_size, cur_len = input_ids.shape
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)

         # per batch-item holding current token in loop.
@@ -748,8 +759,8 @@ class FlaxGenerationMixin:
         batch_size, num_beams, cur_len = input_ids.shape
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)

         # per batch,beam-item holding current token in loop.
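
All three hunks pin an explicit integer dtype on these values, with the conditional covering the case where eos_token_id is legitimately None. These token ids and cur_len end up as carried state of jit-compiled generation loops, where the carry must keep a stable dtype between initialization and every iteration. A minimal sketch of that pattern (not the library's code):

    import jax.numpy as jnp
    from jax import lax

    def count_tokens(start_len, max_len):
        # Loop state must keep a stable shape/dtype across iterations;
        # jnp.array(x, dtype=jnp.int32) pins it explicitly rather than relying
        # on weak-type promotion rules.
        cur_len = jnp.array(start_len, dtype=jnp.int32)

        def cond(cur_len):
            return cur_len < max_len

        def body(cur_len):
            return cur_len + 1  # stays int32, matching the initial carry

        return lax.while_loop(cond, body, cur_len)

    print(count_tokens(3, 20))  # 20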
src/transformers/generation/tf_utils.py

@@ -702,11 +702,11 @@ class TFGenerationMixin:
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
-            logger.warning(
-                f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
-                " sequence"
-            )
-            generation_config.pad_token_id = generation_config.eos_token_id
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id

         use_xla = not tf.executing_eagerly()
         if use_xla and not self.supports_xla_generation:
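
The TF side now mirrors the FLAX fallback, including taking the first entry when eos_token_id is a list. A minimal usage sketch (distilgpt2 is illustrative; its config defines eos_token_id 50256 but no pad_token_id, so the call below exercises the new path and logs the warning instead of erroring out):

    from transformers import AutoTokenizer, TFAutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = TFAutoModelForCausalLM.from_pretrained("distilgpt2")

    inputs = tokenizer("The dog", return_tensors="tf")

    # No pad_token_id is configured, so generate falls back to the (first)
    # eos_token_id for padding during open-ended generation.
    outputs = model.generate(**inputs, max_length=20)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))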
utils/documentation_tests.txt

@@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
 docs/source/en/model_doc/donut.mdx
 docs/source/en/model_doc/encoder-decoder.mdx
 src/transformers/generation/configuration_utils.py
+src/transformers/generation/flax_utils.py
 src/transformers/generation/tf_utils.py
 src/transformers/generation/utils.py
 src/transformers/models/albert/configuration_albert.py
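
Listing flax_utils.py here opts its docstring examples into the documentation test run, which is what makes the fixed `>>>` example above "functional" in the commit title. A rough local approximation using the standard library; the CI job's exact pytest invocation is not shown in this diff, so treat this as an assumption:

    import doctest

    import transformers.generation.flax_utils as flax_utils

    # Collect and run the `>>>` examples from the module's docstrings, roughly
    # what the documentation test job does for files listed in
    # documentation_tests.txt.
    results = doctest.testmod(flax_utils, verbose=False)
    print(f"attempted={results.attempted} failed={results.failed}")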