Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a81fe4e1
Unverified
Commit
a81fe4e1
authored
Feb 14, 2023
by
Joao Gante
Committed by
GitHub
Feb 14, 2023
Browse files
Generate: input expansion for any model input (#21624)
parent
13e03e61
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
66 deletions
+57
-66
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+37
-42
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+10
-14
tests/models/blip_2/test_modeling_blip_2.py
tests/models/blip_2/test_modeling_blip_2.py
+10
-10
No files found.
src/transformers/generation/tf_utils.py
View file @
a81fe4e1
...
@@ -986,17 +986,13 @@ class TFGenerationMixin:
...
@@ -986,17 +986,13 @@ class TFGenerationMixin:
)
)
# 11. broadcast inputs to the desired number of beams
# 11. broadcast inputs to the desired number of beams
input_ids
=
self
.
_expand_to_num_beams
(
input_ids
,
num_beams
=
generation_config
.
num_beams
)
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
input_ids
=
input_ids
,
if
"encoder_outputs"
in
model_kwargs
:
expand_size
=
generation_config
.
num_beams
,
model_kwargs
[
"encoder_outputs"
][
"last_hidden_state"
]
=
self
.
_expand_to_num_beams
(
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_kwargs
[
"encoder_outputs"
][
"last_hidden_state"
],
num_beams
=
generation_config
.
num_beams
expand_in_new_axis
=
True
,
)
**
model_kwargs
,
)
if
"attention_mask"
in
model_kwargs
:
model_kwargs
[
"attention_mask"
]
=
self
.
_expand_to_num_beams
(
model_kwargs
[
"attention_mask"
],
num_beams
=
generation_config
.
num_beams
)
# 12. run beam search
# 12. run beam search
return
self
.
beam_search
(
return
self
.
beam_search
(
...
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
...
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
logits_warper
=
self
.
_get_logits_warper
(
generation_config
=
generation_config
)
logits_warper
=
self
.
_get_logits_warper
(
generation_config
=
generation_config
)
# 12. broadcast inputs to the desired number of beams
# 12. broadcast inputs to the desired number of beams
input_ids
=
self
.
_expand_to_num_beams
(
input_ids
,
num_beams
=
generation_config
.
num_beams
)
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
input_ids
=
input_ids
,
if
"encoder_outputs"
in
model_kwargs
:
expand_size
=
generation_config
.
num_beams
,
model_kwargs
[
"encoder_outputs"
][
"last_hidden_state"
]
=
self
.
_expand_to_num_beams
(
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_kwargs
[
"encoder_outputs"
][
"last_hidden_state"
],
num_beams
=
generation_config
.
num_beams
expand_in_new_axis
=
True
,
)
**
model_kwargs
,
)
if
"attention_mask"
in
model_kwargs
:
model_kwargs
[
"attention_mask"
]
=
self
.
_expand_to_num_beams
(
model_kwargs
[
"attention_mask"
],
num_beams
=
generation_config
.
num_beams
)
# 13. run beam sample (beam search with sampling)
# 13. run beam sample (beam search with sampling)
return
self
.
beam_search
(
return
self
.
beam_search
(
...
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
...
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
**
model_kwargs
,
**
model_kwargs
,
)
)
@
staticmethod
def
_expand_to_num_beams
(
tensor
:
tf
.
Tensor
,
num_beams
:
int
)
->
tf
.
Tensor
:
shape
=
shape_list
(
tensor
)
return
tf
.
broadcast_to
(
tensor
[:,
None
],
(
shape
[
0
],
num_beams
)
+
tuple
(
shape
[
1
:]))
def
_prepare_attention_mask_for_generation
(
def
_prepare_attention_mask_for_generation
(
self
,
self
,
inputs
:
tf
.
Tensor
,
inputs
:
tf
.
Tensor
,
...
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
...
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
expand_size
:
int
=
1
,
expand_size
:
int
=
1
,
is_encoder_decoder
:
bool
=
False
,
is_encoder_decoder
:
bool
=
False
,
input_ids
:
Optional
[
tf
.
Tensor
]
=
None
,
input_ids
:
Optional
[
tf
.
Tensor
]
=
None
,
expand_in_new_axis
:
bool
=
False
,
**
model_kwargs
,
**
model_kwargs
,
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
Any
]]:
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
Any
]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
"""
if
input_ids
is
not
None
:
Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
input_ids
=
tf
.
repeat
(
input_ids
,
expand_size
,
axis
=
0
)
depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
`expand_in_new_axis=True`
"""
if
model_kwargs
.
get
(
"token_type_ids"
)
is
not
None
:
def
_expand_tensor
(
tensor
:
tf
.
Tensor
):
model_kwargs
[
"token_type_ids"
]
=
tf
.
repeat
(
model_kwargs
[
"token_type_ids"
],
expand_size
,
axis
=
0
)
if
expand_in_new_axis
:
shape
=
shape_list
(
tensor
)
return
tf
.
broadcast_to
(
tensor
[:,
None
],
(
shape
[
0
],
expand_size
)
+
tuple
(
shape
[
1
:]))
else
:
return
tf
.
repeat
(
tensor
,
expand_size
,
axis
=
0
)
if
model_kwargs
.
get
(
"attention_mask"
)
is
not
None
:
def
_expand_dict_for_generation
(
dict_to_expand
):
model_kwargs
[
"attention_mask"
]
=
tf
.
repeat
(
model_kwargs
[
"attention_mask"
],
expand_size
,
axis
=
0
)
for
key
in
dict_to_expand
:
if
dict_to_expand
[
key
]
is
not
None
and
isinstance
(
dict_to_expand
[
key
],
tf
.
Tensor
):
dict_to_expand
[
key
]
=
_expand_tensor
(
dict_to_expand
[
key
])
return
dict_to_expand
if
model_kwargs
.
get
(
"decoder_attention_mask"
)
is
not
None
:
if
input_ids
is
not
None
:
model_kwargs
[
"decoder_attention_mask"
]
=
tf
.
repeat
(
input_ids
=
_expand_tensor
(
input_ids
)
model_kwargs
[
"decoder_attention_mask"
],
expand_size
,
axis
=
0
)
model_kwargs
=
_expand_dict_for_generation
(
model_kwargs
)
if
is_encoder_decoder
:
if
is_encoder_decoder
:
encoder_outputs
=
model_kwargs
.
get
(
"encoder_outputs"
)
if
model_kwargs
.
get
(
"encoder_outputs"
)
is
None
:
if
encoder_outputs
is
None
:
raise
ValueError
(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
raise
ValueError
(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs
[
"last_hidden_state"
]
=
tf
.
repeat
(
encoder_outputs
.
last_hidden_state
,
expand_size
,
axis
=
0
)
model_kwargs
[
"encoder_outputs"
]
=
_expand_dict_for_generation
(
model_kwargs
[
"encoder_outputs"
])
model_kwargs
[
"encoder_outputs"
]
=
encoder_outputs
return
input_ids
,
model_kwargs
return
input_ids
,
model_kwargs
...
...
src/transformers/generation/utils.py
View file @
a81fe4e1
...
@@ -671,26 +671,22 @@ class GenerationMixin:
...
@@ -671,26 +671,22 @@ class GenerationMixin:
**
model_kwargs
,
**
model_kwargs
,
)
->
Tuple
[
torch
.
LongTensor
,
Dict
[
str
,
Any
]]:
)
->
Tuple
[
torch
.
LongTensor
,
Dict
[
str
,
Any
]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
def
_expand_dict_for_generation
(
dict_to_expand
):
for
key
in
dict_to_expand
:
if
dict_to_expand
[
key
]
is
not
None
and
isinstance
(
dict_to_expand
[
key
],
torch
.
Tensor
):
dict_to_expand
[
key
]
=
dict_to_expand
[
key
].
repeat_interleave
(
expand_size
,
dim
=
0
)
return
dict_to_expand
if
input_ids
is
not
None
:
if
input_ids
is
not
None
:
input_ids
=
input_ids
.
repeat_interleave
(
expand_size
,
dim
=
0
)
input_ids
=
input_ids
.
repeat_interleave
(
expand_size
,
dim
=
0
)
if
model_kwargs
.
get
(
"token_type_ids"
)
is
not
None
:
model_kwargs
=
_expand_dict_for_generation
(
model_kwargs
)
model_kwargs
[
"token_type_ids"
]
=
model_kwargs
[
"token_type_ids"
].
repeat_interleave
(
expand_size
,
dim
=
0
)
if
model_kwargs
.
get
(
"attention_mask"
)
is
not
None
:
model_kwargs
[
"attention_mask"
]
=
model_kwargs
[
"attention_mask"
].
repeat_interleave
(
expand_size
,
dim
=
0
)
if
is_encoder_decoder
:
if
is_encoder_decoder
:
encoder_outputs
=
model_kwargs
.
get
(
"encoder_outputs"
)
if
model_kwargs
.
get
(
"encoder_outputs"
)
is
None
:
if
encoder_outputs
is
None
:
raise
ValueError
(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
raise
ValueError
(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs
[
"last_hidden_state"
]
=
encoder_outputs
.
last_hidden_state
.
repeat_interleave
(
model_kwargs
[
"encoder_outputs"
]
=
_expand_dict_for_generation
(
model_kwargs
[
"encoder_outputs"
])
expand_size
,
dim
=
0
)
model_kwargs
[
"encoder_outputs"
]
=
encoder_outputs
decoder_attention_mask
=
model_kwargs
.
get
(
"decoder_attention_mask"
)
if
decoder_attention_mask
is
not
None
:
model_kwargs
[
"decoder_attention_mask"
]
=
decoder_attention_mask
.
repeat_interleave
(
expand_size
,
dim
=
0
)
return
input_ids
,
model_kwargs
return
input_ids
,
model_kwargs
...
...
tests/models/blip_2/test_modeling_blip_2.py
View file @
a81fe4e1
...
@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
)
self
.
assertEqual
(
generated_text
,
"it's not a city, it's a beach"
)
self
.
assertEqual
(
generated_text
,
"it's not a city, it's a beach"
)
def
test_inference_opt_batched
(
self
):
def
test_inference_opt_batched
_beam_search
(
self
):
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-opt-2.7b"
)
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-opt-2.7b"
)
model
=
Blip2ForConditionalGeneration
.
from_pretrained
(
"Salesforce/blip2-opt-2.7b"
).
to
(
torch_device
)
model
=
Blip2ForConditionalGeneration
.
from_pretrained
(
"Salesforce/blip2-opt-2.7b"
).
to
(
torch_device
)
...
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
image
=
prepare_img
()
image
=
prepare_img
()
inputs
=
processor
(
images
=
[
image
,
image
],
return_tensors
=
"pt"
).
to
(
torch_device
)
inputs
=
processor
(
images
=
[
image
,
image
],
return_tensors
=
"pt"
).
to
(
torch_device
)
predictions
=
model
.
generate
(
**
inputs
)
predictions
=
model
.
generate
(
**
inputs
,
num_beams
=
2
)
# Test output
# Test output
(in this case, slightly different from greedy search)
self
.
assertEqual
(
predictions
[
0
].
tolist
(),
[
2
,
102
,
693
,
2828
,
15
,
5
,
4105
,
19
,
10
,
2335
,
50118
])
self
.
assertEqual
(
predictions
[
0
].
tolist
(),
[
2
,
102
,
693
,
2828
,
15
,
5
,
4105
,
19
,
69
,
2335
,
50118
])
self
.
assertEqual
(
predictions
[
1
].
tolist
(),
[
2
,
102
,
693
,
2828
,
15
,
5
,
4105
,
19
,
10
,
2335
,
50118
])
self
.
assertEqual
(
predictions
[
1
].
tolist
(),
[
2
,
102
,
693
,
2828
,
15
,
5
,
4105
,
19
,
69
,
2335
,
50118
])
def
test_inference_t5
(
self
):
def
test_inference_t5
(
self
):
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
)
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
)
...
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
)
self
.
assertEqual
(
generated_text
,
"san diego"
)
self
.
assertEqual
(
generated_text
,
"san diego"
)
def
test_inference_t5_batched
(
self
):
def
test_inference_t5_batched
_beam_search
(
self
):
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
)
processor
=
Blip2Processor
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
)
model
=
Blip2ForConditionalGeneration
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
).
to
(
torch_device
)
model
=
Blip2ForConditionalGeneration
.
from_pretrained
(
"Salesforce/blip2-flan-t5-xl"
).
to
(
torch_device
)
...
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
image
=
prepare_img
()
image
=
prepare_img
()
inputs
=
processor
(
images
=
[
image
,
image
],
return_tensors
=
"pt"
).
to
(
torch_device
)
inputs
=
processor
(
images
=
[
image
,
image
],
return_tensors
=
"pt"
).
to
(
torch_device
)
predictions
=
model
.
generate
(
**
inputs
)
predictions
=
model
.
generate
(
**
inputs
,
num_beams
=
2
)
# Test output
# Test output
(in this case, slightly different from greedy search)
self
.
assertEqual
(
predictions
[
0
].
tolist
(),
[
0
,
2335
,
1
556
,
28
,
1782
,
30
,
8
,
2
60
8
,
1
])
self
.
assertEqual
(
predictions
[
0
].
tolist
(),
[
0
,
3
,
9
,
2335
,
1
9
,
3823
,
30
,
8
,
2608
,
2
8
,
1
60
,
1782
,
1
])
self
.
assertEqual
(
predictions
[
1
].
tolist
(),
[
0
,
2335
,
1
556
,
28
,
1782
,
30
,
8
,
2
60
8
,
1
])
self
.
assertEqual
(
predictions
[
1
].
tolist
(),
[
0
,
3
,
9
,
2335
,
1
9
,
3823
,
30
,
8
,
2608
,
2
8
,
1
60
,
1782
,
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment