Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b1065aa0
Unverified
Commit
b1065aa0
authored
May 22, 2024
by
Raushan Turganbay
Committed by
GitHub
May 22, 2024
Browse files
Generation: get special tokens from model config (#30899)
* fix * let's do this way? * codestyle * update * add tests
parent
1d568dfa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
1 deletion
+53
-1
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+23
-1
tests/generation/test_utils.py
tests/generation/test_utils.py
+30
-0
No files found.
src/transformers/generation/utils.py
View file @
b1065aa0
...
@@ -1361,6 +1361,23 @@ class GenerationMixin:
...
@@ -1361,6 +1361,23 @@ class GenerationMixin:
self
.
_cache
.
reset
()
self
.
_cache
.
reset
()
return
self
.
_cache
return
self
.
_cache
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
Union
[
int
,
List
[
int
]]
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
decoder_start_token_id
=
(
decoder_start_token_id
if
decoder_start_token_id
is
not
None
else
self
.
generation_config
.
decoder_start_token_id
)
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
generation_config
.
bos_token_id
if
decoder_start_token_id
is
not
None
:
return
decoder_start_token_id
elif
bos_token_id
is
not
None
:
return
bos_token_id
else
:
return
def
_prepare_special_tokens
(
def
_prepare_special_tokens
(
self
,
self
,
generation_config
:
GenerationConfig
,
generation_config
:
GenerationConfig
,
...
@@ -1385,11 +1402,16 @@ class GenerationMixin:
...
@@ -1385,11 +1402,16 @@ class GenerationMixin:
return
token
return
token
return
torch
.
tensor
(
token
,
device
=
device
,
dtype
=
torch
.
long
)
return
torch
.
tensor
(
token
,
device
=
device
,
dtype
=
torch
.
long
)
# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
if
self
.
config
.
is_encoder_decoder
:
generation_config
.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
generation_config
.
decoder_start_token_id
,
generation_config
.
bos_token_id
)
bos_token_id
=
_tensor_or_none
(
generation_config
.
bos_token_id
,
device
=
device
)
bos_token_id
=
_tensor_or_none
(
generation_config
.
bos_token_id
,
device
=
device
)
eos_token_id
=
_tensor_or_none
(
generation_config
.
eos_token_id
,
device
=
device
)
eos_token_id
=
_tensor_or_none
(
generation_config
.
eos_token_id
,
device
=
device
)
pad_token_id
=
_tensor_or_none
(
generation_config
.
pad_token_id
,
device
=
device
)
pad_token_id
=
_tensor_or_none
(
generation_config
.
pad_token_id
,
device
=
device
)
decoder_start_token_id
=
_tensor_or_none
(
generation_config
.
decoder_start_token_id
,
device
=
device
)
decoder_start_token_id
=
_tensor_or_none
(
generation_config
.
decoder_start_token_id
,
device
=
device
)
decoder_start_token_id
=
decoder_start_token_id
if
decoder_start_token_id
is
not
None
else
bos_token_id
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if
eos_token_id
is
not
None
and
eos_token_id
.
ndim
==
0
:
if
eos_token_id
is
not
None
and
eos_token_id
.
ndim
==
0
:
...
...
tests/generation/test_utils.py
View file @
b1065aa0
...
@@ -65,6 +65,7 @@ if is_torch_available():
...
@@ -65,6 +65,7 @@ if is_torch_available():
GenerateBeamEncoderDecoderOutput
,
GenerateBeamEncoderDecoderOutput
,
GenerateDecoderOnlyOutput
,
GenerateDecoderOnlyOutput
,
GenerateEncoderDecoderOutput
,
GenerateEncoderDecoderOutput
,
GenerationConfig
,
GreedySearchDecoderOnlyOutput
,
GreedySearchDecoderOnlyOutput
,
GreedySearchEncoderDecoderOutput
,
GreedySearchEncoderDecoderOutput
,
LogitsProcessorList
,
LogitsProcessorList
,
...
@@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_batched_ids
.
tolist
())
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_batched_ids
.
tolist
())
def
test_decoder_start_id_from_config
(
self
):
# Refer to: (#30899)
articles
=
[
"Justin Timberlake and Jessica Biel, welcome to parenthood."
,
"Michael Phelps is arguably the most decorated Olympian of all time."
,
]
bart_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
input_ids
=
bart_tokenizer
(
articles
,
return_tensors
=
"pt"
,
padding
=
True
).
input_ids
.
to
(
torch_device
)
decoder_start_token_id
=
bart_model
.
generation_config
.
decoder_start_token_id
# we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type
outputs
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
))
# If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config
bart_model
.
generation_config
.
decoder_start_token_id
=
None
bart_model
.
generation_config
.
bos_token_id
=
None
outputs_with_user_id
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
,
decoder_start_token_id
=
decoder_start_token_id
),
)
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_with_user_id
.
tolist
())
with
self
.
assertRaises
(
ValueError
):
outputs
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
))
def
test_contrastive_search_batched
(
self
):
def
test_contrastive_search_batched
(
self
):
# PT-only test: TF doesn't have constrained beam search
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment