Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
709dc432
Unverified
Commit
709dc432
authored
Feb 01, 2024
by
fxmarty
Committed by
GitHub
Feb 01, 2024
Browse files
Fix symbolic_trace with kv cache (#28724)
* fix symbolic_trace with kv cache * comment & better test
parent
eb8e7a00
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
135 additions
and
111 deletions
+135
-111
src/transformers/utils/fx.py
src/transformers/utils/fx.py
+16
-4
tests/test_modeling_common.py
tests/test_modeling_common.py
+119
-107
No files found.
src/transformers/utils/fx.py
View file @
709dc432
...
@@ -765,7 +765,7 @@ class HFTracer(Tracer):
...
@@ -765,7 +765,7 @@ class HFTracer(Tracer):
)
)
def
_generate_dummy_input
(
def
_generate_dummy_input
(
self
,
model
:
PreTrainedModel
,
input_name
:
str
,
shape
:
List
[
int
]
self
,
model
:
PreTrainedModel
,
input_name
:
str
,
shape
:
List
[
int
]
,
input_names
:
List
[
str
]
)
->
Dict
[
str
,
torch
.
Tensor
]:
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Generates dummy input for model inference recording."""
"""Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
...
@@ -774,6 +774,11 @@ class HFTracer(Tracer):
...
@@ -774,6 +774,11 @@ class HFTracer(Tracer):
device
=
model
.
device
device
=
model
.
device
inputs_dict
=
{}
inputs_dict
=
{}
# when tracing a model with KV cache, we simply need to ensure that the KV cache length is larger than one to
# rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
# After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
kv_cache_length
=
5
if
input_name
in
[
"labels"
,
"start_positions"
,
"end_positions"
]:
if
input_name
in
[
"labels"
,
"start_positions"
,
"end_positions"
]:
batch_size
=
shape
[
0
]
batch_size
=
shape
[
0
]
if
model_class_name
in
[
if
model_class_name
in
[
...
@@ -883,7 +888,14 @@ class HFTracer(Tracer):
...
@@ -883,7 +888,14 @@ class HFTracer(Tracer):
# Generating big sequence length for audio inputs.
# Generating big sequence length for audio inputs.
seq_length
=
_generate_random_int
(
low
=
10000
,
high
=
20000
)
seq_length
=
_generate_random_int
(
low
=
10000
,
high
=
20000
)
inputs_dict
[
input_name
]
=
torch
.
zeros
(
batch_size
,
seq_length
,
dtype
=
torch
.
float
,
device
=
device
)
inputs_dict
[
input_name
]
=
torch
.
zeros
(
batch_size
,
seq_length
,
dtype
=
torch
.
float
,
device
=
device
)
elif
"mask"
in
input_name
or
"ids"
in
input_name
:
elif
"mask"
in
input_name
:
if
"past_key_values"
in
input_names
:
mask_shape
=
[
shape
[
0
],
shape
[
1
]
+
kv_cache_length
]
else
:
mask_shape
=
shape
inputs_dict
[
input_name
]
=
torch
.
zeros
(
mask_shape
,
dtype
=
torch
.
long
,
device
=
device
)
elif
"ids"
in
input_name
:
inputs_dict
[
input_name
]
=
torch
.
zeros
(
shape
,
dtype
=
torch
.
long
,
device
=
device
)
inputs_dict
[
input_name
]
=
torch
.
zeros
(
shape
,
dtype
=
torch
.
long
,
device
=
device
)
elif
"past_key_values"
in
input_name
:
elif
"past_key_values"
in
input_name
:
if
model
.
config
.
model_type
not
in
_FX_SUPPORTED_MODELS_WITH_KV_CACHE
:
if
model
.
config
.
model_type
not
in
_FX_SUPPORTED_MODELS_WITH_KV_CACHE
:
...
@@ -893,7 +905,7 @@ class HFTracer(Tracer):
...
@@ -893,7 +905,7 @@ class HFTracer(Tracer):
num_heads
=
model
.
config
.
num_attention_heads
num_heads
=
model
.
config
.
num_attention_heads
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
cache_shape
=
(
shape
[
0
],
num_heads
,
0
,
head_dim
)
cache_shape
=
(
shape
[
0
],
num_heads
,
kv_cache_length
,
head_dim
)
pkv
=
tuple
(
pkv
=
tuple
(
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
device
),
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
device
),
...
@@ -1095,7 +1107,7 @@ class HFTracer(Tracer):
...
@@ -1095,7 +1107,7 @@ class HFTracer(Tracer):
if
isinstance
(
root
,
self
.
supported_archs
)
or
type
(
root
).
__qualname__
.
startswith
(
if
isinstance
(
root
,
self
.
supported_archs
)
or
type
(
root
).
__qualname__
.
startswith
(
(
"_deserialize_graph_module"
,
"_CodeOnlyModule"
)
(
"_deserialize_graph_module"
,
"_CodeOnlyModule"
)
):
):
inputs
.
update
(
self
.
_generate_dummy_input
(
root
,
input_name
,
shape
))
inputs
.
update
(
self
.
_generate_dummy_input
(
root
,
input_name
,
shape
,
input_names
=
input_names
))
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Could not generate input named
{
input_name
}
for because root is not a"
f
"Could not generate input named
{
input_name
}
for because root is not a"
...
...
tests/test_modeling_common.py
View file @
709dc432
...
@@ -1053,132 +1053,144 @@ class ModelTesterMixin:
...
@@ -1053,132 +1053,144 @@ class ModelTesterMixin:
model
.
eval
()
model
.
eval
()
inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
output_loss
)
inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
output_loss
)
try
:
# We may want to test several inputs (various shapes, etc.).
if
model
.
config
.
is_encoder_decoder
:
inputs_to_test
=
[
inputs
]
model
.
config
.
use_cache
=
False
# FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
labels
=
inputs
.
get
(
"labels"
,
None
)
input_names
=
[
"attention_mask"
,
"decoder_attention_mask"
,
"decoder_input_ids"
,
"input_features"
,
"input_ids"
,
"input_values"
,
]
if
labels
is
not
None
:
input_names
.
append
(
"labels"
)
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inputs
.
items
()
if
k
in
input_names
}
input_names
=
list
(
filtered_inputs
.
keys
())
model_output
=
model
(
**
filtered_inputs
)
if
model
.
config
.
is_encoder_decoder
:
model
.
config
.
use_cache
=
False
# FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
labels
=
inputs
.
get
(
"labels"
,
None
)
input_names
=
[
"attention_mask"
,
"decoder_attention_mask"
,
"decoder_input_ids"
,
"input_features"
,
"input_ids"
,
"input_values"
,
]
if
labels
is
not
None
:
input_names
.
append
(
"labels"
)
else
:
input_names
=
[
"attention_mask"
,
"bbox"
,
"input_features"
,
"input_ids"
,
"input_values"
,
"pixel_values"
,
"token_type_ids"
,
"visual_feats"
,
"visual_pos"
,
]
traced_model
=
symbolic_trace
(
model
,
input_names
)
labels
=
inputs
.
get
(
"labels"
,
None
)
traced_output
=
traced_model
(
**
filtered_inputs
)
start_positions
=
inputs
.
get
(
"start_positions"
,
None
)
else
:
end_positions
=
inputs
.
get
(
"end_positions"
,
None
)
input_names
=
[
if
labels
is
not
None
:
"attention_mask"
,
input_names
.
append
(
"labels"
)
"bbox"
,
if
start_positions
is
not
None
:
"input_features"
,
input_names
.
append
(
"start_positions"
)
"input_ids"
,
if
end_positions
is
not
None
:
"input_values"
,
input_names
.
append
(
"end_positions"
)
"pixel_values"
,
"token_type_ids"
,
if
model
.
config
.
model_type
in
_FX_SUPPORTED_MODELS_WITH_KV_CACHE
:
"visual_feats"
,
input_names
.
append
(
"past_key_values"
)
"visual_pos"
,
]
# Generally, model_tester.prepare_config_and_inputs_for_common does not seem to generate past key values inputs.
if
"past_key_values"
not
in
inputs
:
labels
=
inputs
.
get
(
"labels"
,
None
)
batch_size
=
inputs
[
next
(
iter
(
inputs
))].
shape
[
0
]
start_positions
=
inputs
.
get
(
"start_positions"
,
None
)
num_heads
=
model
.
config
.
num_attention_heads
end_positions
=
inputs
.
get
(
"end_positions"
,
None
)
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
if
labels
is
not
None
:
input_names
.
append
(
"labels"
)
cache_shape
=
(
batch_size
,
num_heads
,
0
,
head_dim
)
if
start_positions
is
not
None
:
empty_pkv
=
tuple
(
input_names
.
append
(
"start_positions"
)
(
if
end_positions
is
not
None
:
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
input_names
.
append
(
"end_positions"
)
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
if
model
.
config
.
model_type
in
_FX_SUPPORTED_MODELS_WITH_KV_CACHE
:
input_names
.
append
(
"past_key_values"
)
# Generally, model_tester.prepare_config_and_inputs_for_common does not seem to generate past key values inputs.
if
"past_key_values"
not
in
inputs
:
batch_size
=
inputs
[
next
(
iter
(
inputs
))].
shape
[
0
]
num_heads
=
model
.
config
.
num_attention_heads
head_dim
=
model
.
config
.
hidden_size
//
model
.
config
.
num_attention_heads
cache_shape
=
(
batch_size
,
num_heads
,
0
,
head_dim
)
pkv
=
tuple
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
)
for
i
in
range
(
model
.
config
.
num_hidden_layers
)
)
)
for
i
in
range
(
model
.
config
.
num_hidden_layers
)
)
inputs
[
"past_key_values"
]
=
pkv
cache_length
=
9
cache_shape
=
(
batch_size
,
num_heads
,
cache_length
,
head_dim
)
non_empty_pkv
=
tuple
(
(
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
torch
.
rand
(
cache_shape
,
dtype
=
torch
.
float
,
device
=
torch_device
),
)
for
i
in
range
(
model
.
config
.
num_hidden_layers
)
)
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inputs
.
items
()
if
k
in
input_names
}
inps
=
copy
.
deepcopy
(
inputs_to_test
[
0
])
input_names
=
list
(
filtered_inputs
.
keys
())
if
model
.
__class__
.
__name__
in
set
(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
.
values
())
and
(
inputs_to_test
[
0
][
"past_key_values"
]
=
empty_pkv
not
hasattr
(
model
.
config
,
"problem_type"
)
or
model
.
config
.
problem_type
is
None
):
model
.
config
.
problem_type
=
"single_label_classification"
traced_model
=
symbolic_trace
(
model
,
input_names
)
inps
[
"past_key_values"
]
=
non_empty_pkv
inputs_to_test
.
append
(
inps
)
with
torch
.
no_grad
():
past_mask
=
torch
.
ones
(
batch_size
,
cache_length
,
device
=
torch_device
,
dtype
=
torch
.
float
)
traced_output
=
traced_model
(
**
filtered_inputs
)
inputs_to_test
[
1
][
"attention_mask"
]
=
torch
.
cat
(
model_output
=
model
(
**
filtered_inputs
)
(
past_mask
,
inputs_to_test
[
1
][
"attention_mask"
]),
dim
=
1
)
except
Exception
as
e
:
for
inps
in
inputs_to_test
:
self
.
fail
(
f
"Couldn't trace module:
{
e
}
"
)
filtered_inputs
=
{
k
:
v
for
(
k
,
v
)
in
inps
.
items
()
if
k
in
input_names
}
input_names
=
list
(
filtered_inputs
.
keys
())
def
flatten_output
(
output
):
if
model
.
__class__
.
__name__
in
set
(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
.
values
())
and
(
flatten
=
[]
not
hasattr
(
model
.
config
,
"problem_type"
)
or
model
.
config
.
problem_type
is
None
for
x
in
output
:
):
if
isinstance
(
x
,
(
tuple
,
list
)):
model
.
config
.
problem_type
=
"single_label_classification"
flatten
+=
flatten_output
(
x
)
elif
not
isinstance
(
x
,
torch
.
Tensor
):
continue
else
:
flatten
.
append
(
x
)
return
flatten
model_output
=
flatten_output
(
model_output
)
traced_model
=
symbolic_trace
(
model
,
input_names
)
traced_output
=
flatten_output
(
traced_output
)
num_outputs
=
len
(
model_output
)
for
i
in
range
(
num_outputs
):
with
torch
.
no_grad
():
self
.
assertTrue
(
traced_output
=
traced_model
(
**
filtered_inputs
)
torch
.
allclose
(
model_output
[
i
],
traced_output
[
i
]),
model_output
=
model
(
**
filtered_inputs
)
f
"traced
{
i
}
th output doesn't match model
{
i
}
th output for
{
model_class
}
"
,
)
# Test that the model can be serialized and restored properly
def
flatten_output
(
output
):
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir_name
:
flatten
=
[]
pkl_file_name
=
os
.
path
.
join
(
tmp_dir_name
,
"model.pkl"
)
for
x
in
output
:
try
:
if
isinstance
(
x
,
(
tuple
,
list
))
:
with
open
(
pkl_file_name
,
"wb"
)
as
f
:
flatten
+=
flatten_output
(
x
)
pickle
.
dump
(
traced_model
,
f
)
elif
not
isinstance
(
x
,
torch
.
Tensor
):
with
open
(
pkl_file_name
,
"rb"
)
as
f
:
continue
loaded
=
pickle
.
load
(
f
)
else
:
except
Exception
as
e
:
flatten
.
append
(
x
)
self
.
fail
(
f
"Couldn't serialize / deserialize the traced model:
{
e
}
"
)
return
flatten
loaded_output
=
loaded
(
**
filtered_inputs
)
model_output
=
flatten_output
(
model_output
)
loaded_output
=
flatten_output
(
loaded_output
)
traced_output
=
flatten_output
(
traced_output
)
num_outputs
=
len
(
model_output
)
for
i
in
range
(
num_outputs
):
for
i
in
range
(
num_outputs
):
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
allclose
(
model_output
[
i
],
load
ed_output
[
i
]),
torch
.
allclose
(
model_output
[
i
],
trac
ed_output
[
i
]),
f
"
serialized model
{
i
}
th output doesn't match model
{
i
}
th output for
{
model_class
}
"
,
f
"
traced
{
i
}
th output doesn't match model
{
i
}
th output for
{
model_class
}
"
,
)
)
# Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
# Test that the model can be serialized and restored properly
# (Even with this call, there is still a memory leak of ~0.04MB)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir_name
:
self
.
clear_torch_jit_class_registry
()
pkl_file_name
=
os
.
path
.
join
(
tmp_dir_name
,
"model.pkl"
)
try
:
with
open
(
pkl_file_name
,
"wb"
)
as
f
:
pickle
.
dump
(
traced_model
,
f
)
with
open
(
pkl_file_name
,
"rb"
)
as
f
:
loaded
=
pickle
.
load
(
f
)
except
Exception
as
e
:
self
.
fail
(
f
"Couldn't serialize / deserialize the traced model:
{
e
}
"
)
loaded_output
=
loaded
(
**
filtered_inputs
)
loaded_output
=
flatten_output
(
loaded_output
)
for
i
in
range
(
num_outputs
):
self
.
assertTrue
(
torch
.
allclose
(
model_output
[
i
],
loaded_output
[
i
]),
f
"serialized model
{
i
}
th output doesn't match model
{
i
}
th output for
{
model_class
}
"
,
)
# Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB)
self
.
clear_torch_jit_class_registry
()
def
test_headmasking
(
self
):
def
test_headmasking
(
self
):
if
not
self
.
test_head_masking
:
if
not
self
.
test_head_masking
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment