Commit a81fe4e1 (unverified)
Authored Feb 14, 2023 by Joao Gante, committed by GitHub on Feb 14, 2023
Parent: 13e03e61

Generate: input expansion for any model input (#21624)
This commit replaces the hand-maintained, per-name expansion of generation inputs (input_ids, attention_mask, token_type_ids, decoder_attention_mask, encoder_outputs) in both the PyTorch and TensorFlow generation mixins with a generic loop that expands every tensor in model_kwargs, so beam search and other multi-candidate decoding methods work for any model input, including non-standard ones such as BLIP-2's pixel_values.

Changes: 3 changed files, with 57 additions and 66 deletions (+57, -66)

  src/transformers/generation/tf_utils.py       +37  -42
  src/transformers/generation/utils.py          +10  -14
  tests/models/blip_2/test_modeling_blip_2.py   +10  -10
src/transformers/generation/tf_utils.py  (view file @ a81fe4e1)

@@ -986,17 +986,13 @@ class TFGenerationMixin:
             )

             # 11. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-            if "encoder_outputs" in model_kwargs:
-                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-                )
-
-            if "attention_mask" in model_kwargs:
-                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-                )
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                expand_in_new_axis=True,
+                **model_kwargs,
+            )

             # 12. run beam search
             return self.beam_search(
...
@@ -1025,17 +1021,13 @@ class TFGenerationMixin:
             logits_warper = self._get_logits_warper(generation_config=generation_config)

             # 12. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
-
-            if "encoder_outputs" in model_kwargs:
-                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
-                )
-
-            if "attention_mask" in model_kwargs:
-                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
-                )
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                expand_in_new_axis=True,
+                **model_kwargs,
+            )

             # 13. run beam sample (beam search with sampling)
             return self.beam_search(
...
@@ -1054,11 +1046,6 @@ class TFGenerationMixin:
                 **model_kwargs,
             )

-    @staticmethod
-    def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        shape = shape_list(tensor)
-        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
-
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
...
@@ -1142,29 +1129,37 @@ class TFGenerationMixin:
         expand_size: int = 1,
         is_encoder_decoder: bool = False,
         input_ids: Optional[tf.Tensor] = None,
+        expand_in_new_axis: bool = False,
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
-        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
-        if input_ids is not None:
-            input_ids = tf.repeat(input_ids, expand_size, axis=0)
+        """
+        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
+        depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
+        `expand_in_new_axis=True`
+        """

-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0)
+        def _expand_tensor(tensor: tf.Tensor):
+            if expand_in_new_axis:
+                shape = shape_list(tensor)
+                return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
+            else:
+                return tf.repeat(tensor, expand_size, axis=0)

-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0)
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
+                    dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
+            return dict_to_expand

-        if model_kwargs.get("decoder_attention_mask") is not None:
-            model_kwargs["decoder_attention_mask"] = tf.repeat(
-                model_kwargs["decoder_attention_mask"], expand_size, axis=0
-            )
+        if input_ids is not None:
+            input_ids = _expand_tensor(input_ids)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)

         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0)
-            model_kwargs["encoder_outputs"] = encoder_outputs
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

         return input_ids, model_kwargs
...
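The TF refactor folds the old _expand_to_num_beams broadcast into _expand_inputs_for_generation behind the new expand_in_new_axis flag. For reference, a minimal standalone sketch of the two expansion modes, assuming plain TensorFlow (tensor.shape.as_list() stands in for the library's shape_list helper, and the toy tensor is made up):

import tensorflow as tf

# Toy batch: batch_size=2, seq_len=3.
tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
expand_size = 2  # e.g. num_beams

# expand_in_new_axis=False: [batch_size, ...] -> [batch_size * expand_size, ...]
flat = tf.repeat(tensor, expand_size, axis=0)  # shape (4, 3)

# expand_in_new_axis=True: [batch_size, ...] -> [batch_size, expand_size, ...],
# the layout the TF beam search implementation expects.
shape = tensor.shape.as_list()
new_axis = tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))  # shape (2, 2, 3)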
src/transformers/generation/utils.py  (view file @ a81fe4e1)

@@ -671,26 +671,22 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)

-        if model_kwargs.get("token_type_ids") is not None:
-            model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].repeat_interleave(expand_size, dim=0)
-
-        if model_kwargs.get("attention_mask") is not None:
-            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
+        model_kwargs = _expand_dict_for_generation(model_kwargs)

         if is_encoder_decoder:
-            encoder_outputs = model_kwargs.get("encoder_outputs")
-            if encoder_outputs is None:
+            if model_kwargs.get("encoder_outputs") is None:
                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
-                expand_size, dim=0
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-            decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
-            if decoder_attention_mask is not None:
-                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

         return input_ids, model_kwargs
...
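On the PyTorch side the same idea applies: instead of expanding a fixed list of known kwargs, every tensor in model_kwargs is expanded with repeat_interleave. A minimal sketch of that loop outside the mixin (the kwarg names and shapes below are illustrative assumptions, not the library's required inputs):

import torch

expand_size = 2  # e.g. num_beams or num_return_sequences

# Any tensor-valued model input is expanded the same way, whatever its name.
model_kwargs = {
    "attention_mask": torch.ones(2, 3, dtype=torch.long),
    "pixel_values": torch.randn(2, 3, 224, 224),  # e.g. an image input of a multimodal model
    "use_cache": True,  # non-tensor values are left untouched
}

for key in model_kwargs:
    if model_kwargs[key] is not None and isinstance(model_kwargs[key], torch.Tensor):
        model_kwargs[key] = model_kwargs[key].repeat_interleave(expand_size, dim=0)

print(model_kwargs["pixel_values"].shape)  # torch.Size([4, 3, 224, 224])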
tests/models/blip_2/test_modeling_blip_2.py  (view file @ a81fe4e1)

@@ -797,7 +797,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")

-    def test_inference_opt_batched(self):
+    def test_inference_opt_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
...
@@ -805,11 +805,11 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)

-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)

-        # Test output
-        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
-        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 69, 2335, 50118])

     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
...
@@ -842,7 +842,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "san diego")

-    def test_inference_t5_batched(self):
+    def test_inference_t5_batched_beam_search(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
...
@@ -850,8 +850,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)

-        predictions = model.generate(**inputs)
+        predictions = model.generate(**inputs, num_beams=2)

-        # Test output
-        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
-        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        # Test output (in this case, slightly different from greedy search)
+        self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
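The renamed BLIP-2 tests exercise the new path end to end: with num_beams=2, pixel_values must be expanded like any other tensor input. A sketch of the equivalent user-facing call, assuming the same checkpoints as the tests (the demo image URL is an assumption, not the test fixture):

import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(device)

# Any image works here; this COCO demo URL is an assumption.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Batched input: the same image twice, as in the integration tests.
inputs = processor(images=[image, image], return_tensors="pt").to(device)

# Beam search now works for pixel-only inputs: pixel_values is broadcast to
# num_beams candidates by _expand_inputs_for_generation like any other tensor.
predictions = model.generate(**inputs, num_beams=2)
print(processor.batch_decode(predictions, skip_special_tokens=True))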