chenpangpang/transformers, commit 938cb047 (unverified)

Generate: add Bloom fixes for contrastive search (#20213)

Authored by Joao Gante on Nov 14, 2022; committed via GitHub on Nov 14, 2022.
Parent commit: fda12563
Showing 3 changed files, with 72 additions and 27 deletions:

- src/transformers/generation/utils.py (+19, -6)
- src/transformers/models/bloom/modeling_bloom.py (+51, -17)
- tests/generation/test_utils.py (+2, -4)
src/transformers/generation/utils.py
```diff
@@ -672,8 +672,7 @@ class GenerationMixin:
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _extract_past_from_model_output(outputs: ModelOutput):
+    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
         past = None
         if "past_key_values" in outputs:
             past = outputs.past_key_values
@@ -681,13 +680,24 @@ class GenerationMixin:
             past = outputs.mems
         elif "past_buckets_states" in outputs:
             past = outputs.past_buckets_states
+
+        # Bloom fix: standardizes the cache format when requested
+        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+            batch_size = outputs.logits.shape[0]
+            past = self._convert_to_standard_cache(past, batch_size=batch_size)
         return past
 
     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(outputs)
+        model_kwargs["past"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
 
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
@@ -1939,7 +1949,10 @@ class GenerationMixin:
             logit_for_next_step = outputs.logits[:, -1, :]
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                standardize_cache_format=True,
             )
 
             # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2001,7 +2014,7 @@ class GenerationMixin:
             outputs = self(
                 **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
             )
-            next_past_key_values = self._extract_past_from_model_output(outputs)
+            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
 
             logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models
```
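The net effect of this plumbing is that contrastive search can now work with models whose cache layout differs from the default, Bloom included. As a rough usage sketch (not part of the commit; the checkpoint name and generation settings are illustrative assumptions), contrastive search is selected in `generate()` by passing `penalty_alpha` together with `top_k`:

```python
# Hypothetical usage sketch: exercise the contrastive-search path with a Bloom
# checkpoint, which goes through the standardize_cache_format logic added above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("The BLOOM language model is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 selects the contrastive search decoding path
output_ids = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```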
src/transformers/models/bloom/modeling_bloom.py
```diff
@@ -506,6 +506,45 @@ class BloomPreTrainedModel(PreTrainedModel):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_bloom_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 BLOOM_START_DOCSTRING = r"""
```
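For intuition, the two helpers only reinterpret the leading dimensions of the cached key/value tensors via `view()`; no data is copied. A standalone sketch of the round trip with dummy tensors (the sizes are arbitrary, not taken from the diff):

```python
# Standalone sketch of the cache layout round trip implemented by
# _convert_to_standard_cache / _convert_to_bloom_cache, using dummy tensors.
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 16, 7
bsz_x_heads = batch_size * num_heads

# One layer of Bloom's native cache: key is [bsz*heads, head_dim, seq], value is [bsz*heads, seq, head_dim]
bloom_layer = (
    torch.randn(bsz_x_heads, head_dim, seq_length),
    torch.randn(bsz_x_heads, seq_length, head_dim),
)

# -> standard layout: [batch_size, num_heads, ...], as most models store their cache
standard_layer = (
    bloom_layer[0].view(batch_size, num_heads, head_dim, seq_length),
    bloom_layer[1].view(batch_size, num_heads, seq_length, head_dim),
)

# -> back to Bloom's layout; the round trip is lossless because view() only reshapes
roundtrip_layer = (
    standard_layer[0].view(bsz_x_heads, head_dim, seq_length),
    standard_layer[1].view(bsz_x_heads, seq_length, head_dim),
)
assert torch.equal(roundtrip_layer[0], bloom_layer[0])
assert torch.equal(roundtrip_layer[1], bloom_layer[1])
```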
```diff
@@ -811,6 +850,10 @@ class BloomForCausalLM(BloomPreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+            if past[0][0].shape[0] == input_ids.shape[0]:
+                past = self._convert_to_bloom_cache(past)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past,
```
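The `past[0][0].shape[0] == input_ids.shape[0]` test works because a standardized cache keeps `batch_size` as the leading dimension of the key tensor, while Bloom's native cache uses `batch_size * num_heads`. A minimal illustration with made-up sizes (not from the diff):

```python
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Bloom-native key cache: leading dim is batch_size * num_heads
bloom_key = torch.zeros(batch_size * num_heads, head_dim, seq_length)
# Standardized key cache: leading dim is batch_size, as in most other models
standard_key = torch.zeros(batch_size, num_heads, head_dim, seq_length)

# at this point in prepare_inputs_for_generation, input_ids is [batch_size, 1]
input_ids = torch.zeros(batch_size, 1, dtype=torch.long)
print(bloom_key.shape[0] == input_ids.shape[0])     # False -> cache already in Bloom's format
print(standard_key.shape[0] == input_ids.shape[0])  # True  -> convert with _convert_to_bloom_cache
```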
```diff
@@ -896,9 +939,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
 
-    @staticmethod
     def _reorder_cache(
-        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -907,28 +949,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         Output shares the same memory storage as `past`.
         """
-        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
-        batch_size = len(beam_idx)
-        num_heads = batch_size_times_num_heads // batch_size
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
             past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
         }
-        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
-        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
+        reordered_past = tuple(
             (
-                layer_past[0]
-                .view(batch_size, num_heads, head_dim, seq_length)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1]
-                .view(batch_size, num_heads, seq_length, head_dim)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, seq_length, head_dim),
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in past
+            for layer_past in standardized_past
         )
+        return self._convert_to_bloom_cache(reordered_past)
 
 
 @add_start_docstrings(
```
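With the cache standardized first, beam reordering reduces to a single `index_select` along the batch dimension, which is what `reordered_past` performs per layer before converting back to Bloom's layout. A small sketch of that step on a dummy key tensor (shapes and indices are illustrative only):

```python
import torch

batch_size, num_heads, head_dim, seq_length = 4, 2, 8, 3
key = torch.randn(batch_size, num_heads, head_dim, seq_length)  # one standardized key cache

# beam_idx[i] = which existing beam the i-th surviving beam continues from
beam_idx = torch.tensor([1, 0, 0, 2])
reordered_key = key.index_select(0, beam_idx)

assert torch.equal(reordered_key[0], key[1])
assert torch.equal(reordered_key[1], key[0])
```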
tests/generation/test_utils.py
```diff
@@ -1411,9 +1411,8 @@ class GenerationTesterMixin:
         # check `generate()` and `contrastive_search()` are equal
         for model_class in self.all_generative_model_classes:
 
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
@@ -1434,9 +1433,8 @@ class GenerationTesterMixin:
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
 
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
         # enable cache
```