Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
998b5bb5
Unverified
Commit
998b5bb5
authored
Mar 26, 2024
by
Zhihao Lin
Committed by
GitHub
Mar 26, 2024
Browse files
Allow `bos_token_id is None` during the generation with `inputs_embeds` (#29772)
* update * add ut * update
parent
b9ceb03d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
3 deletions
+15
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+4
-3
tests/generation/test_utils.py
tests/generation/test_utils.py
+11
-0
No files found.
src/transformers/generation/utils.py
View file @
998b5bb5
...
@@ -436,9 +436,6 @@ class GenerationMixin:
...
@@ -436,9 +436,6 @@ class GenerationMixin:
shape
=
encoder_outputs
.
last_hidden_state
.
size
()[:
-
1
]
shape
=
encoder_outputs
.
last_hidden_state
.
size
()[:
-
1
]
return
torch
.
ones
(
shape
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
-
100
return
torch
.
ones
(
shape
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
-
100
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size
=
1
batch_size
=
1
...
@@ -449,6 +446,10 @@ class GenerationMixin:
...
@@ -449,6 +446,10 @@ class GenerationMixin:
if
"inputs_embeds"
in
model_kwargs
:
if
"inputs_embeds"
in
model_kwargs
:
return
torch
.
ones
((
batch_size
,
0
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
torch
.
ones
((
batch_size
,
0
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
bos_token_id
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
bos_token_id
def
_prepare_attention_mask_for_generation
(
def
_prepare_attention_mask_for_generation
(
...
...
tests/generation/test_utils.py
View file @
998b5bb5
...
@@ -1467,6 +1467,17 @@ class GenerationTesterMixin:
...
@@ -1467,6 +1467,17 @@ class GenerationTesterMixin:
past_kv
[
i
][
1
].
shape
,
(
batch_size
,
num_attention_heads
,
seq_length
,
per_head_embed_dim
)
past_kv
[
i
][
1
].
shape
,
(
batch_size
,
num_attention_heads
,
seq_length
,
per_head_embed_dim
)
)
)
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
    """Check `generate` when `bos_token_id` is explicitly `None`.

    Two calls are made against a tiny GPT-2 checkpoint:

    * with `inputs_embeds` supplied, the call is expected to complete
      (its output is not inspected);
    * with no inputs at all, the call is expected to raise `ValueError`.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "Today a dragon flew over Paris."

    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model = model.to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Turn the prompt into embeddings using the model's own input embedding layer.
    encoded = tokenizer(prompt, return_tensors="pt")
    token_ids = encoded.input_ids.to(torch_device)
    embedding_layer = model.get_input_embeddings()
    prompt_embeds = embedding_layer(token_ids)

    # Embeddings are provided: generation must go through without a BOS token.
    model.generate(inputs_embeds=prompt_embeds, max_length=20, bos_token_id=None)

    # Nothing to start from (no ids, no embeddings, no BOS token): must fail.
    with self.assertRaises(ValueError):
        model.generate(max_length=20, bos_token_id=None)
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
# if fails, you should probably update the `prepare_inputs_for_generation` function
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment