Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b1065aa0
Unverified
Commit
b1065aa0
authored
May 22, 2024
by
Raushan Turganbay
Committed by
GitHub
May 22, 2024
Browse files
Generation: get special tokens from model config (#30899)
* fix * let's do this way? * codestyle * update * add tests
parent
1d568dfa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
1 deletion
+53
-1
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+23
-1
tests/generation/test_utils.py
tests/generation/test_utils.py
+30
-0
No files found.
src/transformers/generation/utils.py
View file @
b1065aa0
...
@@ -1361,6 +1361,23 @@ class GenerationMixin:
...
@@ -1361,6 +1361,23 @@ class GenerationMixin:
self
.
_cache
.
reset
()
self
.
_cache
.
reset
()
return
self
.
_cache
return
self
.
_cache
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
Union
[
int
,
List
[
int
]]
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
decoder_start_token_id
=
(
decoder_start_token_id
if
decoder_start_token_id
is
not
None
else
self
.
generation_config
.
decoder_start_token_id
)
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
generation_config
.
bos_token_id
if
decoder_start_token_id
is
not
None
:
return
decoder_start_token_id
elif
bos_token_id
is
not
None
:
return
bos_token_id
else
:
return
def
_prepare_special_tokens
(
def
_prepare_special_tokens
(
self
,
self
,
generation_config
:
GenerationConfig
,
generation_config
:
GenerationConfig
,
...
@@ -1385,11 +1402,16 @@ class GenerationMixin:
...
@@ -1385,11 +1402,16 @@ class GenerationMixin:
return
token
return
token
return
torch
.
tensor
(
token
,
device
=
device
,
dtype
=
torch
.
long
)
return
torch
.
tensor
(
token
,
device
=
device
,
dtype
=
torch
.
long
)
# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
if
self
.
config
.
is_encoder_decoder
:
generation_config
.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
generation_config
.
decoder_start_token_id
,
generation_config
.
bos_token_id
)
bos_token_id
=
_tensor_or_none
(
generation_config
.
bos_token_id
,
device
=
device
)
bos_token_id
=
_tensor_or_none
(
generation_config
.
bos_token_id
,
device
=
device
)
eos_token_id
=
_tensor_or_none
(
generation_config
.
eos_token_id
,
device
=
device
)
eos_token_id
=
_tensor_or_none
(
generation_config
.
eos_token_id
,
device
=
device
)
pad_token_id
=
_tensor_or_none
(
generation_config
.
pad_token_id
,
device
=
device
)
pad_token_id
=
_tensor_or_none
(
generation_config
.
pad_token_id
,
device
=
device
)
decoder_start_token_id
=
_tensor_or_none
(
generation_config
.
decoder_start_token_id
,
device
=
device
)
decoder_start_token_id
=
_tensor_or_none
(
generation_config
.
decoder_start_token_id
,
device
=
device
)
decoder_start_token_id
=
decoder_start_token_id
if
decoder_start_token_id
is
not
None
else
bos_token_id
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if
eos_token_id
is
not
None
and
eos_token_id
.
ndim
==
0
:
if
eos_token_id
is
not
None
and
eos_token_id
.
ndim
==
0
:
...
...
tests/generation/test_utils.py
View file @
b1065aa0
...
@@ -65,6 +65,7 @@ if is_torch_available():
...
@@ -65,6 +65,7 @@ if is_torch_available():
GenerateBeamEncoderDecoderOutput
,
GenerateBeamEncoderDecoderOutput
,
GenerateDecoderOnlyOutput
,
GenerateDecoderOnlyOutput
,
GenerateEncoderDecoderOutput
,
GenerateEncoderDecoderOutput
,
GenerationConfig
,
GreedySearchDecoderOnlyOutput
,
GreedySearchDecoderOnlyOutput
,
GreedySearchEncoderDecoderOutput
,
GreedySearchEncoderDecoderOutput
,
LogitsProcessorList
,
LogitsProcessorList
,
...
@@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_batched_ids
.
tolist
())
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_batched_ids
.
tolist
())
def
test_decoder_start_id_from_config
(
self
):
# Refer to: (#30899)
articles
=
[
"Justin Timberlake and Jessica Biel, welcome to parenthood."
,
"Michael Phelps is arguably the most decorated Olympian of all time."
,
]
bart_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
input_ids
=
bart_tokenizer
(
articles
,
return_tensors
=
"pt"
,
padding
=
True
).
input_ids
.
to
(
torch_device
)
decoder_start_token_id
=
bart_model
.
generation_config
.
decoder_start_token_id
# we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type
outputs
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
))
# If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config
bart_model
.
generation_config
.
decoder_start_token_id
=
None
bart_model
.
generation_config
.
bos_token_id
=
None
outputs_with_user_id
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
,
decoder_start_token_id
=
decoder_start_token_id
),
)
self
.
assertListEqual
(
outputs
.
tolist
(),
outputs_with_user_id
.
tolist
())
with
self
.
assertRaises
(
ValueError
):
outputs
=
bart_model
.
generate
(
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
))
def
test_contrastive_search_batched
(
self
):
def
test_contrastive_search_batched
(
self
):
# PT-only test: TF doesn't have constrained beam search
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment