Unverified commit ba743700, authored by fxmarty, committed by GitHub

transformers.fx.symbolic_trace supports inputs_embeds (#31574)



* symbolic trace supports inputs_embeds

* fix test?

* Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e5ca9b05
src/transformers/utils/fx.py
@@ -995,6 +995,13 @@ class HFTracer(Tracer):
             inputs_dict[input_name] = torch.zeros(
                 *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
             )
+        elif "inputs_embeds" in input_name:
+            batch_size = shape[0]
+            sequence_length = shape[-1]
+            inputs_dict[input_name] = torch.zeros(
+                batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device
+            )
         elif "visual_feats" in input_name:
             inputs_dict[input_name] = torch.zeros(
                 shape
...
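With this addition, HFTracer can fabricate a dummy inputs_embeds tensor of shape (batch_size, sequence_length, hidden_size), so a model can be symbolically traced on embeddings rather than token ids. A minimal usage sketch follows; the tiny checkpoint name is an assumption chosen for illustration and is not part of this commit.

# Usage sketch (not from this commit): trace a model on inputs_embeds.
# Assumption: "hf-internal-testing/tiny-random-BertModel" is used purely for illustration.
import torch
from transformers import AutoModel
from transformers.utils.fx import symbolic_trace

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel").eval()

# inputs_embeds can now be listed among the concrete input names to trace on.
traced = symbolic_trace(model, input_names=["inputs_embeds", "attention_mask"])

batch_size, sequence_length = 2, 8
inputs_embeds = torch.rand(batch_size, sequence_length, model.config.hidden_size)
attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)

with torch.no_grad():
    outputs = traced(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

Note that the dummy tensor's last dimension comes from model.config.hidden_size, matching the new branch above, so models whose embedding width differs from hidden_size are outside the scope of this change.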
tests/test_modeling_common.py
@@ -1158,6 +1158,7 @@ class ModelTesterMixin:
             "input_features",
             "input_ids",
             "input_values",
+            "inputs_embeds",
             "pixel_values",
             "token_type_ids",
             "visual_feats",
@@ -1214,16 +1215,27 @@
                         (past_mask, inputs_to_test[1]["attention_mask"]), dim=1
                     )
+            if "inputs_embeds" in inspect.signature(model.forward).parameters:
+                inputs_to_test.append(
+                    {
+                        "inputs_embeds": torch.rand(
+                            2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device
+                        )
+                    }
+                )
             for inps in inputs_to_test:
                 filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
-                input_names = list(filtered_inputs.keys())
+                input_names_to_trace = list(filtered_inputs.keys())
                 if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                     not hasattr(model.config, "problem_type") or model.config.problem_type is None
                 ):
                     model.config.problem_type = "single_label_classification"
-                traced_model = symbolic_trace(model, input_names)
+                model.config.use_cache = "past_key_values" in input_names_to_trace
+                traced_model = symbolic_trace(model, input_names_to_trace)
                 with torch.no_grad():
                     traced_output = traced_model(**filtered_inputs)
...
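For reference, a condensed, standalone version of the equivalence check this test performs (traced graph vs. eager forward pass on inputs_embeds) might look as follows; the checkpoint name and the small output-flattening helper are assumptions, not code from this commit.

# Condensed sketch of the test's traced-vs-eager check (assumed tiny checkpoint).
import torch
from transformers import AutoModel
from transformers.utils.fx import symbolic_trace

torch_device = "cpu"
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel").to(torch_device).eval()

inputs = {
    "inputs_embeds": torch.rand(2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device)
}
input_names_to_trace = list(inputs.keys())
traced_model = symbolic_trace(model, input_names_to_trace)

def to_tensor_list(output):
    # Normalize dict-style or tuple-style outputs to a flat list of tensors
    # (assumes both containers yield tensors in the same order).
    values = output.values() if isinstance(output, dict) else output
    return [v for v in values if isinstance(v, torch.Tensor)]

with torch.no_grad():
    traced_output = traced_model(**inputs)
    eager_output = model(**inputs)

for traced_t, eager_t in zip(to_tensor_list(traced_output), to_tensor_list(eager_output)):
    torch.testing.assert_close(traced_t, eager_t)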