Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9dfd6a4b
Unverified
Commit
9dfd6a4b
authored
Apr 13, 2023
by
Joao Gante
Committed by
GitHub
Apr 13, 2023
Browse files
Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
parent
90ce374d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
123 additions
and
66 deletions
+123
-66
docs/source/en/model_doc/pix2struct.mdx
docs/source/en/model_doc/pix2struct.mdx
+1
-1
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+35
-10
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+37
-14
src/transformers/models/pix2struct/modeling_pix2struct.py
src/transformers/models/pix2struct/modeling_pix2struct.py
+0
-29
tests/generation/test_framework_agnostic.py
tests/generation/test_framework_agnostic.py
+30
-2
tests/generation/test_utils.py
tests/generation/test_utils.py
+20
-10
No files found.
docs/source/en/model_doc/pix2struct.mdx
View file @
9dfd6a4b
...
...
@@ -69,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2str
## Pix2StructForConditionalGeneration
[[autodoc]] Pix2StructForConditionalGeneration
- forward
\ No newline at end of file
- forward
src/transformers/generation/tf_utils.py
View file @
9dfd6a4b
...
...
@@ -837,12 +837,12 @@ class TFGenerationMixin:
# 6. Prepare model inputs which will be used for auto-regressive generation
if
self
.
config
.
is_encoder_decoder
:
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
,
input_ids
,
model_kwargs
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
batch_size
,
model_input_name
=
model_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
generation_config
.
decoder_start_token_id
,
bos_token_id
=
generation_config
.
bos_token_id
,
model_kwargs
=
model_kwargs
,
)
else
:
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
...
...
@@ -1095,16 +1095,41 @@ class TFGenerationMixin:
def
_prepare_decoder_input_ids_for_generation
(
self
,
batch_size
:
int
,
model_input_name
:
str
,
model_kwargs
:
Dict
[
str
,
tf
.
Tensor
],
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
tf
.
Tensor
]]
=
None
,
)
->
tf
.
Tensor
:
# prepare `input_ids` for decoder if model is encoder-decoder
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
tf
.
Tensor
]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if
model_kwargs
is
not
None
and
"decoder_input_ids"
in
model_kwargs
:
return
model_kwargs
.
pop
(
"decoder_input_ids"
)
decoder_input_ids
=
model_kwargs
.
pop
(
"decoder_input_ids"
)
elif
"input_ids"
in
model_kwargs
and
model_input_name
!=
"input_ids"
:
decoder_input_ids
=
model_kwargs
.
pop
(
"input_ids"
)
else
:
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
return
tf
.
ones
((
batch_size
,
1
),
dtype
=
tf
.
int32
)
*
decoder_start_token_id
decoder_input_ids
=
None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
decoder_input_ids_start
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
tf
.
int32
)
*
decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if
decoder_input_ids
is
None
:
decoder_input_ids
=
decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif
tf
.
reduce_all
(
decoder_input_ids
[:,
0
]
!=
decoder_start_token_id
):
decoder_input_ids
=
tf
.
concat
([
decoder_input_ids_start
,
decoder_input_ids
],
axis
=-
1
)
if
"decoder_attention_mask"
in
model_kwargs
:
decoder_attention_mask
=
model_kwargs
[
"decoder_attention_mask"
]
decoder_attention_mask
=
tf
.
concat
(
(
tf
.
ones_like
(
decoder_attention_mask
)[:,
:
1
],
decoder_attention_mask
),
axis
=-
1
,
)
model_kwargs
[
"decoder_attention_mask"
]
=
decoder_attention_mask
return
decoder_input_ids
,
model_kwargs
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
# retrieve decoder_start_token_id for encoder-decoder models
...
...
src/transformers/generation/utils.py
View file @
9dfd6a4b
...
...
@@ -642,18 +642,44 @@ class GenerationMixin:
def
_prepare_decoder_input_ids_for_generation
(
self
,
batch_size
:
int
,
model_input_name
:
str
,
model_kwargs
:
Dict
[
str
,
torch
.
Tensor
],
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
device
:
torch
.
device
=
None
,
)
->
torch
.
LongTensor
:
)
->
Tuple
[
torch
.
LongTensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if
model_kwargs
is
not
None
and
"decoder_input_ids"
in
model_kwargs
:
return
model_kwargs
.
pop
(
"decoder_input_ids"
)
decoder_input_ids
=
model_kwargs
.
pop
(
"decoder_input_ids"
)
elif
"input_ids"
in
model_kwargs
and
model_input_name
!=
"input_ids"
:
decoder_input_ids
=
model_kwargs
.
pop
(
"input_ids"
)
else
:
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
if
device
is
None
:
device
=
self
.
device
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
device
)
*
decoder_start_token_id
decoder_input_ids
=
None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
if
device
is
None
:
device
=
self
.
device
decoder_input_ids_start
=
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
device
)
*
decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if
decoder_input_ids
is
None
:
decoder_input_ids
=
decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif
(
decoder_input_ids
[:,
0
]
!=
decoder_start_token_id
).
all
().
item
():
decoder_input_ids
=
torch
.
cat
([
decoder_input_ids_start
,
decoder_input_ids
],
dim
=-
1
)
if
"decoder_attention_mask"
in
model_kwargs
:
decoder_attention_mask
=
model_kwargs
[
"decoder_attention_mask"
]
decoder_attention_mask
=
torch
.
cat
(
(
torch
.
ones_like
(
decoder_attention_mask
)[:,
:
1
],
decoder_attention_mask
),
dim
=-
1
,
)
model_kwargs
[
"decoder_attention_mask"
]
=
decoder_attention_mask
return
decoder_input_ids
,
model_kwargs
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
decoder_start_token_id
=
(
...
...
@@ -1289,17 +1315,14 @@ class GenerationMixin:
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if
self
.
config
.
is_encoder_decoder
:
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
,
input_ids
,
model_kwargs
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
batch_size
,
model_input_name
=
model_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
generation_config
.
decoder_start_token_id
,
bos_token_id
=
generation_config
.
bos_token_id
,
model_kwargs
=
model_kwargs
,
device
=
inputs_tensor
.
device
,
)
# conditional generation for multi-modal models.
if
"input_ids"
in
model_kwargs
and
model_input_name
==
"pixel_values"
:
input_ids
=
torch
.
cat
([
input_ids
,
model_kwargs
.
pop
(
"input_ids"
)],
dim
=-
1
)
else
:
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
...
...
src/transformers/models/pix2struct/modeling_pix2struct.py
View file @
9dfd6a4b
...
...
@@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
encoder_outputs
=
None
,
**
kwargs
,
):
if
isinstance
(
input_ids
,
torch
.
Tensor
):
# check if the first element of `input_ids` is equal to `input_ids`:
if
(
input_ids
[:,
0
]
!=
self
.
config
.
decoder_start_token_id
).
all
().
item
():
# add `input_ids` as first token to `input_ids`
input_ids
=
torch
.
cat
(
[
torch
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
*
self
.
config
.
decoder_start_token_id
,
input_ids
,
],
dim
=-
1
,
)
if
decoder_attention_mask
is
not
None
:
decoder_attention_mask
=
torch
.
cat
(
[
torch
.
ones
(
(
decoder_attention_mask
.
shape
[
0
],
1
),
dtype
=
torch
.
long
,
device
=
decoder_attention_mask
.
device
,
),
decoder_attention_mask
,
],
dim
=-
1
,
)
elif
input_ids
is
None
:
batch_size
=
flattened_patches
.
shape
[
0
]
input_ids
=
torch
.
LongTensor
([[
self
.
input_ids
]]).
repeat
(
batch_size
,
1
).
to
(
input_ids
.
device
)
if
decoder_attention_mask
is
None
:
decoder_attention_mask
=
torch
.
ones_like
(
input_ids
).
to
(
input_ids
.
device
)
...
...
tests/generation/test_framework_agnostic.py
View file @
9dfd6a4b
...
...
@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
# Decoder only call
outputs
=
bart_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
)
#
29
+ 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
3
2
])
#
1 BOS + 29 (input length)
+ 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
3
3
])
# Encoder decoder call > 20
outputs
=
bart_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
...
...
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
[
token
==
model
.
config
.
pad_token_id
for
token
in
generated_tokens
[
0
][
expectation
:]]
)
self
.
assertTrue
(
unpadded_correct_condition
or
padded_correct_condition
)
def
test_generate_vision2text_conditioning
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForVision2Seq"
]
floats_tensor
=
self
.
framework_dependent_parameters
[
"floats_tensor"
]
create_tensor_fn
=
self
.
framework_dependent_parameters
[
"create_tensor_fn"
]
is_pt
=
not
model_cls
.
__name__
.
startswith
(
"TF"
)
pixel_values
=
floats_tensor
((
2
,
3
,
30
,
30
))
conditioning_input
=
create_tensor_fn
([[
10
],
[
10
]])
# this should be the 2nd output token, after the BOS token
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
if
is_pt
:
pixel_values
=
pixel_values
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
conditioning_input
=
conditioning_input
.
to
(
torch_device
)
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
# decoder_input_ids, if the encoder is not a model with text input)
output_sequences_decoder_input_ids
=
model
.
generate
(
pixel_values
,
max_length
=
5
,
decoder_input_ids
=
conditioning_input
)
output_sequences_input_ids
=
model
.
generate
(
pixel_values
,
max_length
=
5
,
input_ids
=
conditioning_input
)
if
is_pt
:
output_sequences_decoder_input_ids
=
output_sequences_decoder_input_ids
.
cpu
().
numpy
()
output_sequences_input_ids
=
output_sequences_input_ids
.
cpu
().
numpy
()
conditioning_input
=
conditioning_input
.
cpu
().
numpy
()
self
.
assertTrue
(
np
.
array_equal
(
output_sequences_decoder_input_ids
,
output_sequences_input_ids
))
self
.
assertTrue
(
np
.
array_equal
(
output_sequences_decoder_input_ids
[:,
1
:
2
],
conditioning_input
))
tests/generation/test_utils.py
View file @
9dfd6a4b
...
...
@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length
=
20
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length
=
20
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids
=
input_ids
.
expand
(
6
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Greedy
input_ids
=
input_ids
.
expand
(
6
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment