Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
998b5bb5
Unverified
Commit
998b5bb5
authored
Mar 26, 2024
by
Zhihao Lin
Committed by
GitHub
Mar 26, 2024
Browse files
Allow `bos_token_id is None` during the generation with `inputs_embeds` (#29772)
* update * add ut * update
parent
b9ceb03d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
3 deletions
+15
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+4
-3
tests/generation/test_utils.py
tests/generation/test_utils.py
+11
-0
No files found.
src/transformers/generation/utils.py
View file @
998b5bb5
...
...
@@ -436,9 +436,6 @@ class GenerationMixin:
shape
=
encoder_outputs
.
last_hidden_state
.
size
()[:
-
1
]
return
torch
.
ones
(
shape
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
-
100
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size
=
1
...
...
@@ -449,6 +446,10 @@ class GenerationMixin:
if
"inputs_embeds"
in
model_kwargs
:
return
torch
.
ones
((
batch_size
,
0
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
bos_token_id
def
_prepare_attention_mask_for_generation
(
...
...
tests/generation/test_utils.py
View file @
998b5bb5
...
...
@@ -1467,6 +1467,17 @@ class GenerationTesterMixin:
past_kv
[
i
][
1
].
shape
,
(
batch_size
,
num_attention_heads
,
seq_length
,
per_head_embed_dim
)
)
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
    """Generation from `inputs_embeds` must succeed even when `bos_token_id` is None;
    with neither `input_ids` nor `inputs_embeds`, a None `bos_token_id` must raise.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "Today a dragon flew over Paris."

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)

    token_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
    embedded_prompt = model.get_input_embeddings()(token_ids)

    # Should not raise: the batch size is inferred from `inputs_embeds`,
    # so no BOS token is needed to build a dummy `input_ids`.
    model.generate(inputs_embeds=embedded_prompt, max_length=20, bos_token_id=None)

    # Without ids or embeddings, a BOS token is required to initialize generation.
    with self.assertRaises(ValueError):
        model.generate(max_length=20, bos_token_id=None)
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment