Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1d4b7978
Unverified
Commit
1d4b7978
authored
Feb 23, 2023
by
Joao Gante
Committed by
GitHub
Feb 23, 2023
Browse files
Generate: Fix GIT batched captioning (#21738)
parent
78a93d17
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
14 deletions
+44
-14
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+11
-7
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+11
-7
tests/models/git/test_modeling_git.py
tests/models/git/test_modeling_git.py
+22
-0
No files found.
src/transformers/generation/tf_utils.py
View file @
1d4b7978
...
...
@@ -1217,7 +1217,7 @@ class TFGenerationMixin:
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs
[
"input_ids"
]
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
batch_size
=
model_kwargs
[
"inputs_embeds"
].
shape
[
0
]
inputs
,
bos_token_id
,
model_kwargs
=
model_kwargs
)
else
:
if
inputs
is
not
None
:
...
...
@@ -1225,9 +1225,7 @@ class TFGenerationMixin:
inputs
,
input_name
=
model_kwargs
[
"inputs_embeds"
],
"inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
.
get
(
"encoder_outputs"
)
)
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
)
return
inputs
,
input_name
,
model_kwargs
...
...
@@ -1235,13 +1233,13 @@ class TFGenerationMixin:
self
,
inputs
:
Optional
[
tf
.
Tensor
]
=
None
,
bos_token_id
:
Optional
[
int
]
=
None
,
encoder_outputs
:
Optional
[
ModelOutput
]
=
None
,
batch_size
:
Optional
[
int
]
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
tf
.
Tensor
]]
=
None
,
)
->
tf
.
Tensor
:
"""Initializes input ids for generation, if necessary."""
if
inputs
is
not
None
:
return
inputs
encoder_outputs
=
model_kwargs
.
get
(
"encoder_outputs"
)
if
self
.
config
.
is_encoder_decoder
and
encoder_outputs
is
not
None
:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape
=
encoder_outputs
.
last_hidden_state
.
shape
[:
-
1
]
...
...
@@ -1250,7 +1248,13 @@ class TFGenerationMixin:
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
batch_size
=
batch_size
if
batch_size
is
not
None
else
1
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size
=
1
for
value
in
model_kwargs
.
values
():
if
isinstance
(
value
,
tf
.
Tensor
):
batch_size
=
value
.
shape
[
0
]
break
return
tf
.
ones
((
batch_size
,
1
),
dtype
=
tf
.
int32
)
*
bos_token_id
@
staticmethod
...
...
src/transformers/generation/utils.py
View file @
1d4b7978
...
...
@@ -544,7 +544,7 @@ class GenerationMixin:
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs
[
"input_ids"
]
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
batch_size
=
model_kwargs
[
"inputs_embeds"
].
shape
[
0
]
inputs
,
bos_token_id
,
model_kwargs
=
model_kwargs
)
else
:
if
inputs
is
not
None
:
...
...
@@ -552,9 +552,7 @@ class GenerationMixin:
inputs
,
input_name
=
model_kwargs
[
"inputs_embeds"
],
"inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
.
get
(
"encoder_outputs"
)
)
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
)
return
inputs
,
input_name
,
model_kwargs
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
...
...
@@ -567,13 +565,13 @@ class GenerationMixin:
self
,
inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
bos_token_id
:
Optional
[
int
]
=
None
,
encoder_outputs
:
Optional
[
ModelOutput
]
=
None
,
batch_size
:
Optional
[
int
]
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
)
->
torch
.
LongTensor
:
"""Initializes input ids for generation, if necessary."""
if
inputs
is
not
None
:
return
inputs
encoder_outputs
=
model_kwargs
.
get
(
"encoder_outputs"
)
if
self
.
config
.
is_encoder_decoder
and
encoder_outputs
is
not
None
:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape
=
encoder_outputs
.
last_hidden_state
.
size
()[:
-
1
]
...
...
@@ -582,7 +580,13 @@ class GenerationMixin:
if
bos_token_id
is
None
:
raise
ValueError
(
"`bos_token_id` has to be defined when no `input_ids` are provided."
)
batch_size
=
batch_size
if
batch_size
is
not
None
else
1
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size
=
1
for
value
in
model_kwargs
.
values
():
if
isinstance
(
value
,
torch
.
Tensor
):
batch_size
=
value
.
shape
[
0
]
break
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
bos_token_id
def
_prepare_attention_mask_for_generation
(
...
...
tests/models/git/test_modeling_git.py
View file @
1d4b7978
...
...
@@ -340,6 +340,24 @@ class GitModelTester:
self
.
parent
.
assertEqual
(
generated_ids
.
shape
,
(
self
.
batch_size
*
2
,
20
))
def
_test_batched_generate_captioning
(
self
,
config
,
input_ids
,
input_mask
,
pixel_values
):
model
=
GitForCausalLM
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
eval
()
# generate
generated_ids
=
model
.
generate
(
input_ids
=
None
,
# captioning -> no input_ids
attention_mask
=
None
,
pixel_values
=
pixel_values
,
do_sample
=
False
,
max_length
=
20
,
num_beams
=
2
,
num_return_sequences
=
2
,
)
self
.
parent
.
assertEqual
(
generated_ids
.
shape
,
(
self
.
batch_size
*
2
,
20
))
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
...
...
@@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
_test_beam_search_generate
(
*
config_and_inputs
)
def
test_batched_generate_captioning
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
_test_batched_generate_captioning
(
*
config_and_inputs
)
def
test_model_various_embeddings
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
for
type
in
[
"absolute"
,
"relative_key"
,
"relative_key_query"
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment