Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
709dc432
Unverified
Commit
709dc432
authored
Feb 01, 2024
by
fxmarty
Committed by
GitHub
Feb 01, 2024
Browse files
Fix symbolic_trace with kv cache (#28724)
* fix symbolic_trace with kv cache * comment & better test
parent
eb8e7a00
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
135 additions
and
111 deletions
+135
-111
src/transformers/utils/fx.py
src/transformers/utils/fx.py
+16
-4
tests/test_modeling_common.py
tests/test_modeling_common.py
+119
-107
No files found.
src/transformers/utils/fx.py
View file @
709dc432
...
...
@@ -765,7 +765,7 @@ class HFTracer(Tracer):
)
def
_generate_dummy_input
(
self
,
model
:
PreTrainedModel
,
input_name
:
str
,
shape
:
List
[
int
]
self
,
model
:
PreTrainedModel
,
input_name
:
str
,
shape
:
List
[
int
]
,
input_names
:
List
[
str
]
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
...
...
@@ -774,6 +774,11 @@ class HFTracer(Tracer):
device
=
model
.
device
inputs_dict
=
{}
# when tracing a model with KV cache, we simply need to unsure that the KV cache length is larger than one to
# rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
# After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
kv_cache_length
=
5
if
input_name
in
[
"labels"
,
"start_positions"
,
"end_positions"
]:
batch_size
=
shape
[
0
]
if
model_class_name
in
[
...
...
@@ -883,7 +888,14 @@ class HFTracer(Tracer):
# Generating big sequence length for audio inputs.
seq_length
=
_generate_random_int
(
low
=
10000
,
high
=
20000
)
inputs_dict
[
input_name
]
=
torch
.
zeros
(
batch_size
,
seq_length
,
dtype
=
torch
.
float
,
device
=
device
)
elif
"mask"
in
input_name
or
"ids"
in
input_name
:
elif
"mask"
in
input_name
:
if
"past_key_values"
in
input_names
:
mask_shape
=
[
shape
[
0
],
shape
[
1
]
+
kv_cache_length
]
else
:
mask_shape
=
shape
inputs_dict
[
input_name
]
=
torch
.
zeros
(
mask_shape
,
dtype
=
torch
.
long
,
device
=
device
)
elif
"ids"
in
input_name
:
inputs_dict
[
input_name
]
=
torch
.
zeros
(
shape
,
dtype
=
torch
.
long
,
device
=
device
)
elif
"past_key_values"
in
input_name
:
if
model
.
config
.
model_type
not
in
_FX_SUPPORTED_MODELS_WITH_KV_CACHE
:
...
...
@@ -893,7 +905,7 @@ class HFTracer(Tracer):
num_heads
=
model
.
config
.
num_attention_heads
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
cache_shape
=
(
shape
[
0
],
num_heads
,
0
,
head_dim
)
cache_shape
=
(
shape
[
0
],
num_heads
,
kv_cache_length
,
head_dim
)
pkv
=
tuple
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
device
),
...
...
@@ -1095,7 +1107,7 @@ class HFTracer(Tracer):
if
isinstance
(
root
,
self
.
supported_archs
)
or
type
(
root
).
__qualname__
.
startswith
(
(
"_deserialize_graph_module"
,
"_CodeOnlyModule"
)
):
inputs
.
update
(
self
.
_generate_dummy_input
(
root
,
input_name
,
shape
))
inputs
.
update
(
self
.
_generate_dummy_input
(
root
,
input_name
,
shape
,
input_names
=
input_names
))
else
:
raise
RuntimeError
(
f
"Could not generate input named
{
input_name
}
for because root is not a"
...
...
tests/test_modeling_common.py
View file @
709dc432
...
...
@@ -1053,7 +1053,9 @@ class ModelTesterMixin:
model
.
eval
()
inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
output_loss
)
try
:
# We may want to test several inputs (various shapes, etc.).
inputs_to_test
=
[
inputs
]
if
model
.
config
.
is_encoder_decoder
:
model
.
config
.
use_cache
=
False
# FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels
=
inputs
.
get
(
"labels"
,
None
)
...
...
@@ -1067,14 +1069,6 @@ class ModelTesterMixin:
]
if
labels
is
not
None
:
input_names
.
append
(
"labels"
)
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inputs
.
items
()
if
k
in
input_names
}
input_names
=
list
(
filtered_inputs
.
keys
())
model_output
=
model
(
**
filtered_inputs
)
traced_model
=
symbolic_trace
(
model
,
input_names
)
traced_output
=
traced_model
(
**
filtered_inputs
)
else
:
input_names
=
[
"attention_mask"
,
...
...
@@ -1108,7 +1102,17 @@ class ModelTesterMixin:
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
cache_shape
=
(
batch_size
,
num_heads
,
0
,
head_dim
)
pkv
=
tuple
(
empty_pkv
=
tuple
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
)
for
i
in
range
(
model
.
config
.
num_hidden_layers
)
)
cache_length
=
9
cache_shape
=
(
batch_size
,
num_heads
,
cache_length
,
head_dim
)
non_empty_pkv
=
tuple
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
...
...
@@ -1116,9 +1120,20 @@ class ModelTesterMixin:
for
i
in
range
(
model
.
config
.
num_hidden_layers
)
)
inputs
[
"past_key_values"
]
=
pkv
inps
=
copy
.
deepcopy
(
inputs_to_test
[
0
])
inputs_to_test
[
0
][
"past_key_values"
]
=
empty_pkv
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inputs
.
items
()
if
k
in
input_names
}
inps
[
"past_key_values"
]
=
non_empty_pkv
inputs_to_test
.
append
(
inps
)
past_mask
=
torch
.
ones
(
batch_size
,
cache_length
,
device
=
torch_device
,
dtype
=
torch
.
float
)
inputs_to_test
[
1
][
"attention_mask"
]
=
torch
.
cat
(
(
past_mask
,
inputs_to_test
[
1
][
"attention_mask"
]),
dim
=
1
)
for
inps
in
inputs_to_test
:
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inps
.
items
()
if
k
in
input_names
}
input_names
=
list
(
filtered_inputs
.
keys
())
if
model
.
__class__
.
__name__
in
set
(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
.
values
())
and
(
...
...
@@ -1132,9 +1147,6 @@ class ModelTesterMixin:
traced_output
=
traced_model
(
**
filtered_inputs
)
model_output
=
model
(
**
filtered_inputs
)
except
Exception
as
e
:
self
.
fail
(
f
"Couldn't trace module:
{
e
}
"
)
def
flatten_output
(
output
):
flatten
=
[]
for
x
in
output
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment