"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e468192e2f0099faf3948551cd618eb2942cae2c"
Unverified Commit 709dc432 authored by fxmarty, committed by GitHub

Fix symbolic_trace with kv cache (#28724)

* fix symbolic_trace with kv cache

* comment & better test
parent eb8e7a00
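
In practice, this change lets `symbolic_trace` record a decoder model together with a non-empty `past_key_values` input, and the resulting `GraphModule` can afterwards be run with a cache length different from the one used while tracing. Below is a minimal, illustrative sketch of that usage; the checkpoint name and the assumption that a LLaMA-style model type is listed in `_FX_SUPPORTED_MODELS_WITH_KV_CACHE` are not taken from this commit.

```python
# Illustrative sketch only; checkpoint name and supported-model assumption are hypothetical.
import torch
from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()

# Tracing now records a dummy KV cache of length kv_cache_length = 5 and a dummy
# attention_mask padded to cover it, so the cache-dependent control flow is captured.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])

# The traced module can then be called with an arbitrary cache length (3 here, not 5).
batch_size, seq_len, cache_len = 2, 4, 3
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads
past_key_values = tuple(
    (
        torch.rand(batch_size, num_heads, cache_len, head_dim),
        torch.rand(batch_size, num_heads, cache_len, head_dim),
    )
    for _ in range(model.config.num_hidden_layers)
)
input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
attention_mask = torch.ones(batch_size, seq_len + cache_len, dtype=torch.long)

with torch.no_grad():
    outputs = traced(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
```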
src/transformers/utils/fx.py

@@ -765,7 +765,7 @@ class HFTracer(Tracer):
         )
 
     def _generate_dummy_input(
-        self, model: PreTrainedModel, input_name: str, shape: List[int]
+        self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
     ) -> Dict[str, torch.Tensor]:
         """Generates dummy input for model inference recording."""
         # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
@@ -774,6 +774,11 @@ class HFTracer(Tracer):
         device = model.device
         inputs_dict = {}
 
+        # When tracing a model with a KV cache, we simply need to ensure that the KV cache length is larger than one to
+        # rightfully pass certain control flows (example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
+        # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
+        kv_cache_length = 5
+
         if input_name in ["labels", "start_positions", "end_positions"]:
             batch_size = shape[0]
             if model_class_name in [
@@ -883,7 +888,14 @@ class HFTracer(Tracer):
             # Generating big sequence length for audio inputs.
             seq_length = _generate_random_int(low=10000, high=20000)
             inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
-        elif "mask" in input_name or "ids" in input_name:
+        elif "mask" in input_name:
+            if "past_key_values" in input_names:
+                mask_shape = [shape[0], shape[1] + kv_cache_length]
+            else:
+                mask_shape = shape
+
+            inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
+        elif "ids" in input_name:
             inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
         elif "past_key_values" in input_name:
             if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
@@ -893,7 +905,7 @@ class HFTracer(Tracer):
             num_heads = model.config.num_attention_heads
             head_dim = model.config.hidden_size // model.config.num_attention_heads
 
-            cache_shape = (shape[0], num_heads, 0, head_dim)
+            cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
             pkv = tuple(
                 (
                     torch.rand(cache_shape, dtype=torch.float, device=device),
@@ -1095,7 +1107,7 @@ class HFTracer(Tracer):
             if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
                 ("_deserialize_graph_module", "_CodeOnlyModule")
             ):
-                inputs.update(self._generate_dummy_input(root, input_name, shape))
+                inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
             else:
                 raise RuntimeError(
                     f"Could not generate input named {input_name} for because root is not a"
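
The mask branch above is the core of the fix: once `past_key_values` is among the traced inputs, the dummy `attention_mask` has to cover the cached positions as well, otherwise the mask-preparation helpers take the empty-cache path during tracing and the recorded graph does not generalize to non-empty caches. A standalone sketch of the shape rule (illustrative only; `dummy_mask` is a hypothetical helper, not part of the library):

```python
# Sketch of the shape rule the new mask branch implements: with a dummy KV cache of length
# kv_cache_length, the attention_mask must cover past tokens plus current tokens,
# i.e. have shape (batch_size, seq_length + kv_cache_length).
import torch

def dummy_mask(shape, input_names, kv_cache_length=5):
    # shape is (batch_size, seq_length) for the mask input, as in HFTracer._generate_dummy_input
    if "past_key_values" in input_names:
        mask_shape = [shape[0], shape[1] + kv_cache_length]
    else:
        mask_shape = shape
    return torch.zeros(mask_shape, dtype=torch.long)

print(dummy_mask([2, 7], ["input_ids", "attention_mask", "past_key_values"]).shape)  # torch.Size([2, 12])
print(dummy_mask([2, 7], ["input_ids", "attention_mask"]).shape)                     # torch.Size([2, 7])
```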
tests/test_modeling_common.py

@@ -1053,132 +1053,144 @@ class ModelTesterMixin:
         model.eval()
         inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
 
-        try:
-            if model.config.is_encoder_decoder:
-                model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
-                labels = inputs.get("labels", None)
-                input_names = [
-                    "attention_mask",
-                    "decoder_attention_mask",
-                    "decoder_input_ids",
-                    "input_features",
-                    "input_ids",
-                    "input_values",
-                ]
-                if labels is not None:
-                    input_names.append("labels")
-
-                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
-                input_names = list(filtered_inputs.keys())
-
-                model_output = model(**filtered_inputs)
-
-                traced_model = symbolic_trace(model, input_names)
-                traced_output = traced_model(**filtered_inputs)
-            else:
-                input_names = [
-                    "attention_mask",
-                    "bbox",
-                    "input_features",
-                    "input_ids",
-                    "input_values",
-                    "pixel_values",
-                    "token_type_ids",
-                    "visual_feats",
-                    "visual_pos",
-                ]
-
-                labels = inputs.get("labels", None)
-                start_positions = inputs.get("start_positions", None)
-                end_positions = inputs.get("end_positions", None)
-                if labels is not None:
-                    input_names.append("labels")
-                if start_positions is not None:
-                    input_names.append("start_positions")
-                if end_positions is not None:
-                    input_names.append("end_positions")
-
-                if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
-                    input_names.append("past_key_values")
-
-                    # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
-                    if "past_key_values" not in inputs:
-                        batch_size = inputs[next(iter(inputs))].shape[0]
-                        num_heads = model.config.num_attention_heads
-                        head_dim = model.config.hidden_size // model.config.num_attention_heads
-
-                        cache_shape = (batch_size, num_heads, 0, head_dim)
-                        pkv = tuple(
-                            (
-                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-                            )
-                            for i in range(model.config.num_hidden_layers)
-                        )
-
-                        inputs["past_key_values"] = pkv
-
-                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
-                input_names = list(filtered_inputs.keys())
-
-                if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
-                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
-                ):
-                    model.config.problem_type = "single_label_classification"
-
-                traced_model = symbolic_trace(model, input_names)
-
-                with torch.no_grad():
-                    traced_output = traced_model(**filtered_inputs)
-                    model_output = model(**filtered_inputs)
-
-        except Exception as e:
-            self.fail(f"Couldn't trace module: {e}")
-
-        def flatten_output(output):
-            flatten = []
-            for x in output:
-                if isinstance(x, (tuple, list)):
-                    flatten += flatten_output(x)
-                elif not isinstance(x, torch.Tensor):
-                    continue
-                else:
-                    flatten.append(x)
-            return flatten
-
-        model_output = flatten_output(model_output)
-        traced_output = flatten_output(traced_output)
-        num_outputs = len(model_output)
-
-        for i in range(num_outputs):
-            self.assertTrue(
-                torch.allclose(model_output[i], traced_output[i]),
-                f"traced {i}th output doesn't match model {i}th output for {model_class}",
-            )
-
-        # Test that the model can be serialized and restored properly
-        with tempfile.TemporaryDirectory() as tmp_dir_name:
-            pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
-            try:
-                with open(pkl_file_name, "wb") as f:
-                    pickle.dump(traced_model, f)
-                with open(pkl_file_name, "rb") as f:
-                    loaded = pickle.load(f)
-            except Exception as e:
-                self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
-
-            loaded_output = loaded(**filtered_inputs)
-            loaded_output = flatten_output(loaded_output)
-
-            for i in range(num_outputs):
-                self.assertTrue(
-                    torch.allclose(model_output[i], loaded_output[i]),
-                    f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
-                )
-
-        # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
-        # (Even with this call, there are still memory leak by ~0.04MB)
-        self.clear_torch_jit_class_registry()
+        # We may want to test several inputs (various shapes, etc.).
+        inputs_to_test = [inputs]
+
+        if model.config.is_encoder_decoder:
+            model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+            labels = inputs.get("labels", None)
+            input_names = [
+                "attention_mask",
+                "decoder_attention_mask",
+                "decoder_input_ids",
+                "input_features",
+                "input_ids",
+                "input_values",
+            ]
+            if labels is not None:
+                input_names.append("labels")
+        else:
+            input_names = [
+                "attention_mask",
+                "bbox",
+                "input_features",
+                "input_ids",
+                "input_values",
+                "pixel_values",
+                "token_type_ids",
+                "visual_feats",
+                "visual_pos",
+            ]
+
+            labels = inputs.get("labels", None)
+            start_positions = inputs.get("start_positions", None)
+            end_positions = inputs.get("end_positions", None)
+            if labels is not None:
+                input_names.append("labels")
+            if start_positions is not None:
+                input_names.append("start_positions")
+            if end_positions is not None:
+                input_names.append("end_positions")
+
+            if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
+                input_names.append("past_key_values")
+
+                # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
+                if "past_key_values" not in inputs:
+                    batch_size = inputs[next(iter(inputs))].shape[0]
+                    num_heads = model.config.num_attention_heads
+                    head_dim = model.config.hidden_size // model.config.num_attention_heads
+
+                    cache_shape = (batch_size, num_heads, 0, head_dim)
+                    empty_pkv = tuple(
+                        (
+                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                        )
+                        for i in range(model.config.num_hidden_layers)
+                    )
+
+                    cache_length = 9
+                    cache_shape = (batch_size, num_heads, cache_length, head_dim)
+                    non_empty_pkv = tuple(
+                        (
+                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                        )
+                        for i in range(model.config.num_hidden_layers)
+                    )
+
+                    inps = copy.deepcopy(inputs_to_test[0])
+
+                    inputs_to_test[0]["past_key_values"] = empty_pkv
+
+                    inps["past_key_values"] = non_empty_pkv
+                    inputs_to_test.append(inps)
+
+                    past_mask = torch.ones(batch_size, cache_length, device=torch_device, dtype=torch.float)
+                    inputs_to_test[1]["attention_mask"] = torch.cat(
+                        (past_mask, inputs_to_test[1]["attention_mask"]), dim=1
+                    )
+
+        for inps in inputs_to_test:
+            filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
+            input_names = list(filtered_inputs.keys())
+
+            if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
+                not hasattr(model.config, "problem_type") or model.config.problem_type is None
+            ):
+                model.config.problem_type = "single_label_classification"
+
+            traced_model = symbolic_trace(model, input_names)
+
+            with torch.no_grad():
+                traced_output = traced_model(**filtered_inputs)
+                model_output = model(**filtered_inputs)
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
+        # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
+        # (Even with this call, there are still memory leak by ~0.04MB)
+        self.clear_torch_jit_class_registry()
 
     def test_headmasking(self):
         if not self.test_head_masking:
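
The updated test therefore exercises two input sets per model: the original one with an empty cache, and a second one with a length-9 cache whose `attention_mask` is extended to cover the past positions. A small standalone sketch of how that second set is shaped (random stand-in tensors and hypothetical sizes, mirroring the test but not taken from it):

```python
# Illustrative sketch of the second input set built by the updated test: a non-empty
# KV cache of length 9 plus an attention_mask extended to cover those past positions.
import torch

batch_size, seq_length = 2, 7
num_layers, num_heads, head_dim = 2, 4, 8
cache_length = 9

cache_shape = (batch_size, num_heads, cache_length, head_dim)
non_empty_pkv = tuple(
    (torch.rand(cache_shape), torch.rand(cache_shape)) for _ in range(num_layers)
)

attention_mask = torch.ones(batch_size, seq_length, dtype=torch.float)
past_mask = torch.ones(batch_size, cache_length, dtype=torch.float)
# Past positions come first, so the mask covers cache_length + seq_length tokens in total.
full_mask = torch.cat((past_mask, attention_mask), dim=1)
assert full_mask.shape == (batch_size, cache_length + seq_length)
```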