Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9dfd6a4b
Unverified
Commit
9dfd6a4b
authored
Apr 13, 2023
by
Joao Gante
Committed by
GitHub
Apr 13, 2023
Browse files
Generate: handle text conditioning with multimodal encoder-decoder models (#22748)
parent
90ce374d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
123 additions
and
66 deletions
+123
-66
docs/source/en/model_doc/pix2struct.mdx
docs/source/en/model_doc/pix2struct.mdx
+1
-1
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+35
-10
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+37
-14
src/transformers/models/pix2struct/modeling_pix2struct.py
src/transformers/models/pix2struct/modeling_pix2struct.py
+0
-29
tests/generation/test_framework_agnostic.py
tests/generation/test_framework_agnostic.py
+30
-2
tests/generation/test_utils.py
tests/generation/test_utils.py
+20
-10
No files found.
docs/source/en/model_doc/pix2struct.mdx
View file @
9dfd6a4b
src/transformers/generation/tf_utils.py
View file @
9dfd6a4b
...
...
@@ -837,12 +837,12 @@ class TFGenerationMixin:
# 6. Prepare model inputs which will be used for auto-regressive generation
if
self
.
config
.
is_encoder_decoder
:
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
,
input_ids
,
model_kwargs
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
batch_size
,
model_input_name
=
model_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
generation_config
.
decoder_start_token_id
,
bos_token_id
=
generation_config
.
bos_token_id
,
model_kwargs
=
model_kwargs
,
)
else
:
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
...
...
@@ -1095,16 +1095,41 @@ class TFGenerationMixin:
def
_prepare_decoder_input_ids_for_generation
(
self
,
batch_size
:
int
,
model_input_name
:
str
,
model_kwargs
:
Dict
[
str
,
tf
.
Tensor
],
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
tf
.
Tensor
]]
=
None
,
)
->
tf
.
Tensor
:
# prepare `input_ids` for decoder if model is encoder-decoder
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
tf
.
Tensor
]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if
model_kwargs
is
not
None
and
"decoder_input_ids"
in
model_kwargs
:
return
model_kwargs
.
pop
(
"decoder_input_ids"
)
decoder_input_ids
=
model_kwargs
.
pop
(
"decoder_input_ids"
)
elif
"input_ids"
in
model_kwargs
and
model_input_name
!=
"input_ids"
:
decoder_input_ids
=
model_kwargs
.
pop
(
"input_ids"
)
else
:
decoder_input_ids
=
None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
return
tf
.
ones
((
batch_size
,
1
),
dtype
=
tf
.
int32
)
*
decoder_start_token_id
decoder_input_ids_start
=
tf
.
ones
((
batch_size
,
1
),
dtype
=
tf
.
int32
)
*
decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if
decoder_input_ids
is
None
:
decoder_input_ids
=
decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif
tf
.
reduce_all
(
decoder_input_ids
[:,
0
]
!=
decoder_start_token_id
):
decoder_input_ids
=
tf
.
concat
([
decoder_input_ids_start
,
decoder_input_ids
],
axis
=-
1
)
if
"decoder_attention_mask"
in
model_kwargs
:
decoder_attention_mask
=
model_kwargs
[
"decoder_attention_mask"
]
decoder_attention_mask
=
tf
.
concat
(
(
tf
.
ones_like
(
decoder_attention_mask
)[:,
:
1
],
decoder_attention_mask
),
axis
=-
1
,
)
model_kwargs
[
"decoder_attention_mask"
]
=
decoder_attention_mask
return
decoder_input_ids
,
model_kwargs
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
# retrieve decoder_start_token_id for encoder-decoder models
...
...
src/transformers/generation/utils.py
View file @
9dfd6a4b
...
...
@@ -642,18 +642,44 @@ class GenerationMixin:
def
_prepare_decoder_input_ids_for_generation
(
self
,
batch_size
:
int
,
model_input_name
:
str
,
model_kwargs
:
Dict
[
str
,
torch
.
Tensor
],
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
device
:
torch
.
device
=
None
,
)
->
torch
.
LongTensor
:
)
->
Tuple
[
torch
.
LongTensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if
model_kwargs
is
not
None
and
"decoder_input_ids"
in
model_kwargs
:
return
model_kwargs
.
pop
(
"decoder_input_ids"
)
decoder_input_ids
=
model_kwargs
.
pop
(
"decoder_input_ids"
)
elif
"input_ids"
in
model_kwargs
and
model_input_name
!=
"input_ids"
:
decoder_input_ids
=
model_kwargs
.
pop
(
"input_ids"
)
else
:
decoder_input_ids
=
None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
if
device
is
None
:
device
=
self
.
device
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
device
)
*
decoder_start_token_id
decoder_input_ids_start
=
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
device
)
*
decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if
decoder_input_ids
is
None
:
decoder_input_ids
=
decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif
(
decoder_input_ids
[:,
0
]
!=
decoder_start_token_id
).
all
().
item
():
decoder_input_ids
=
torch
.
cat
([
decoder_input_ids_start
,
decoder_input_ids
],
dim
=-
1
)
if
"decoder_attention_mask"
in
model_kwargs
:
decoder_attention_mask
=
model_kwargs
[
"decoder_attention_mask"
]
decoder_attention_mask
=
torch
.
cat
(
(
torch
.
ones_like
(
decoder_attention_mask
)[:,
:
1
],
decoder_attention_mask
),
dim
=-
1
,
)
model_kwargs
[
"decoder_attention_mask"
]
=
decoder_attention_mask
return
decoder_input_ids
,
model_kwargs
def
_get_decoder_start_token_id
(
self
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
)
->
int
:
decoder_start_token_id
=
(
...
...
@@ -1289,17 +1315,14 @@ class GenerationMixin:
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if
self
.
config
.
is_encoder_decoder
:
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
,
input_ids
,
model_kwargs
=
self
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
batch_size
,
model_input_name
=
model_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
generation_config
.
decoder_start_token_id
,
bos_token_id
=
generation_config
.
bos_token_id
,
model_kwargs
=
model_kwargs
,
device
=
inputs_tensor
.
device
,
)
# conditional generation for multi-modal models.
if
"input_ids"
in
model_kwargs
and
model_input_name
==
"pixel_values"
:
input_ids
=
torch
.
cat
([
input_ids
,
model_kwargs
.
pop
(
"input_ids"
)],
dim
=-
1
)
else
:
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
...
...
src/transformers/models/pix2struct/modeling_pix2struct.py
View file @
9dfd6a4b
...
...
@@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
encoder_outputs
=
None
,
**
kwargs
,
):
if
isinstance
(
input_ids
,
torch
.
Tensor
):
# check if the first element of `input_ids` is equal to `input_ids`:
if
(
input_ids
[:,
0
]
!=
self
.
config
.
decoder_start_token_id
).
all
().
item
():
# add `input_ids` as first token to `input_ids`
input_ids
=
torch
.
cat
(
[
torch
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
*
self
.
config
.
decoder_start_token_id
,
input_ids
,
],
dim
=-
1
,
)
if
decoder_attention_mask
is
not
None
:
decoder_attention_mask
=
torch
.
cat
(
[
torch
.
ones
(
(
decoder_attention_mask
.
shape
[
0
],
1
),
dtype
=
torch
.
long
,
device
=
decoder_attention_mask
.
device
,
),
decoder_attention_mask
,
],
dim
=-
1
,
)
elif
input_ids
is
None
:
batch_size
=
flattened_patches
.
shape
[
0
]
input_ids
=
torch
.
LongTensor
([[
self
.
input_ids
]]).
repeat
(
batch_size
,
1
).
to
(
input_ids
.
device
)
if
decoder_attention_mask
is
None
:
decoder_attention_mask
=
torch
.
ones_like
(
input_ids
).
to
(
input_ids
.
device
)
...
...
tests/generation/test_framework_agnostic.py
View file @
9dfd6a4b
...
...
@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
# Decoder only call
outputs
=
bart_model
.
generate
(
decoder_input_ids
=
input_ids
,
max_new_tokens
=
max_new_tokens
)
#
29
+ 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
3
2
])
#
1 BOS + 29 (input length)
+ 3 new tokens
self
.
assertEqual
(
list
(
outputs
.
shape
),
[
1
,
3
3
])
# Encoder decoder call > 20
outputs
=
bart_model
.
generate
(
max_new_tokens
=
max_new_tokens
+
20
)
...
...
@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
[
token
==
model
.
config
.
pad_token_id
for
token
in
generated_tokens
[
0
][
expectation
:]]
)
self
.
assertTrue
(
unpadded_correct_condition
or
padded_correct_condition
)
def
test_generate_vision2text_conditioning
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForVision2Seq"
]
floats_tensor
=
self
.
framework_dependent_parameters
[
"floats_tensor"
]
create_tensor_fn
=
self
.
framework_dependent_parameters
[
"create_tensor_fn"
]
is_pt
=
not
model_cls
.
__name__
.
startswith
(
"TF"
)
pixel_values
=
floats_tensor
((
2
,
3
,
30
,
30
))
conditioning_input
=
create_tensor_fn
([[
10
],
[
10
]])
# this should be the 2nd output token, after the BOS token
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
if
is_pt
:
pixel_values
=
pixel_values
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
conditioning_input
=
conditioning_input
.
to
(
torch_device
)
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
# decoder_input_ids, if the encoder is not a model with text input)
output_sequences_decoder_input_ids
=
model
.
generate
(
pixel_values
,
max_length
=
5
,
decoder_input_ids
=
conditioning_input
)
output_sequences_input_ids
=
model
.
generate
(
pixel_values
,
max_length
=
5
,
input_ids
=
conditioning_input
)
if
is_pt
:
output_sequences_decoder_input_ids
=
output_sequences_decoder_input_ids
.
cpu
().
numpy
()
output_sequences_input_ids
=
output_sequences_input_ids
.
cpu
().
numpy
()
conditioning_input
=
conditioning_input
.
cpu
().
numpy
()
self
.
assertTrue
(
np
.
array_equal
(
output_sequences_decoder_input_ids
,
output_sequences_input_ids
))
self
.
assertTrue
(
np
.
array_equal
(
output_sequences_decoder_input_ids
[:,
1
:
2
],
conditioning_input
))
tests/generation/test_utils.py
View file @
9dfd6a4b
...
...
@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length
=
20
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length
=
20
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids
=
input_ids
.
expand
(
2
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids
=
input_ids
.
expand
(
6
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Greedy
input_ids
=
input_ids
.
expand
(
6
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
input_ids
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
input_ids
.
shape
[
0
],
input_ids
,
model_kwargs
=
bart_model
.
_prepare_decoder_input_ids_for_generation
(
batch_size
=
input_ids
.
shape
[
0
],
model_input_name
=
bart_model
.
main_input_name
,
model_kwargs
=
model_kwargs
,
decoder_start_token_id
=
bart_model
.
config
.
decoder_start_token_id
,
bos_token_id
=
bart_model
.
config
.
bos_token_id
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment