Unverified commit 0fe17f37, authored by Michael Benayoun, committed by GitHub

FX tracing improvement (#14321)

* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Stop recording values inside leaf modules, so that no more values are recorded than will be seen at tracing time (otherwise the recorded values and the values that must be fed to the proxies during tracing fall out of sync, causing errors)

* Add comments and make tracing work for GPT-J and XLNet

* Refactor the handling of num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone use of the autowrap_function feature

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
parent 552f8d30
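
The net effect of the change: `symbolic_trace` no longer takes static `batch_size`/`sequence_length` arguments, and the resulting graph keeps dynamic axes. A minimal sketch of the post-change call, assuming a randomly initialized BERT (the model choice and input names here are illustrative, not part of this diff):

```python
# Hedged sketch of the new tracing API; the BERT model and the input names
# are illustrative assumptions, not taken from this commit.
from transformers import BertConfig, BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

model = BertForSequenceClassification(BertConfig(return_dict=False)).eval()

# Before this commit: symbolic_trace(model, input_names, batch_size=..., sequence_length=...)
# After this commit: shapes are no longer baked into the trace.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```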
@@ -116,8 +116,7 @@ class ModelTesterMixin:
     model_tester = None
     all_model_classes = ()
     all_generative_model_classes = ()
-    fx_ready_model_classes = ()
-    fx_dynamic_ready_model_classes = ()
+    fx_compatible = False
     test_torchscript = True
     test_pruning = True
     test_resize_embeddings = True
@@ -666,19 +665,14 @@ class ModelTesterMixin:
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)

-    def test_torch_fx_dynamic_axes(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
-
-    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
-        if not is_torch_fx_available():
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
             return

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.return_dict = False

-        model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
-        for model_class in model_classes:
+        for model_class in self.all_model_classes:
             model = model_class(config=configs_no_init)
             model.to(torch_device)
             model.eval()
@@ -687,8 +681,6 @@ class ModelTesterMixin:
             try:
                 if model.config.is_encoder_decoder:
                     model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
-                    input_ids = inputs["input_ids"]
-                    decoder_attention_mask = inputs["decoder_attention_mask"]
                     labels = inputs.get("labels", None)
                     input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
                     if labels is not None:
@@ -697,17 +689,7 @@
                     model_output = model(**filtered_inputs)

-                    batch_size = input_ids.shape[0]
-                    encoder_sequence_length = input_ids.shape[1]
-                    decoder_sequence_length = decoder_attention_mask.shape[1]
-                    traced_model = symbolic_trace(
-                        model,
-                        input_names,
-                        batch_size=batch_size if not dynamic_axes else -1,
-                        sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
-                    )
+                    traced_model = symbolic_trace(model, input_names)
                     traced_output = traced_model(**filtered_inputs)
                 else:
                     input_names = ["input_ids", "attention_mask", "token_type_ids"]
@@ -729,23 +711,12 @@
                     model_output = model(**filtered_inputs)

                     rank = len(input_ids.shape)
-                    if rank == 2:
-                        batch_size, sequence_length = input_ids.shape
-                        num_choices = -1
-                    elif rank == 3:
-                        batch_size, num_choices, sequence_length = input_ids.shape
-                    else:
+                    if rank not in [2, 3]:
                         raise NotImplementedError(
                             f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
                         )

-                    traced_model = symbolic_trace(
-                        model,
-                        input_names,
-                        batch_size=batch_size if not dynamic_axes else -1,
-                        sequence_length=sequence_length if not dynamic_axes else -1,
-                        num_choices=num_choices,
-                    )
+                    traced_model = symbolic_trace(model, input_names)
                     traced_output = traced_model(**filtered_inputs)
             except RuntimeError:
...
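For context on what the simplified test exercises: a single traced `GraphModule` should now run on inputs of varying shape without re-tracing. A hedged, self-contained sketch (the model and shapes below are arbitrary illustrative choices, not taken from the diff):

```python
# One trace, several input shapes; all names and shapes are assumptions.
import torch
from transformers import BertConfig, BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

model = BertForSequenceClassification(BertConfig(return_dict=False)).eval()
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])

for batch_size, seq_len in [(2, 8), (4, 16)]:  # two different dynamic shapes
    inputs = {
        "input_ids": torch.randint(0, model.config.vocab_size, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
    }
    logits = traced(**inputs)[0]  # return_dict=False -> tuple output
    assert logits.shape == (batch_size, model.config.num_labels)
```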
@@ -209,8 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else None
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
...
@@ -369,10 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
...
@@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
     all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_missing_keys = False
     test_model_parallel = True
...
@@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
     )
     all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_missing_keys = False
     test_pruning = False
     test_model_parallel = False
...
@@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
     test_model_parallel = False
...
@@ -283,9 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True
     # test_resize_embeddings = False
     test_head_masking = False
...
@@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
...
@@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
+    fx_compatible = True

     def setUp(self):
         self.model_tester = RobertaModelTester(self)
...
@@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     test_pruning = False
     test_torchscript = True
...
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
+    fx_compatible = True
     test_pruning = False

     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
...
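All of the per-file changes above follow one pattern: the pair of `fx_ready_model_classes` / `fx_dynamic_ready_model_classes` tuples collapses into a single `fx_compatible` boolean on the test class. A stripped-down, runnable sketch of that opt-in mechanism (dummy class names; the real mixin also checks `is_torch_fx_available()` and does far more work):

```python
import unittest

class ModelTesterMixin:
    fx_compatible = False  # opt-in flag introduced by this commit; off by default

    def _create_and_check_torch_fx_tracing(self):
        # The real mixin also bails out when torch.fx is unavailable.
        if not self.fx_compatible:
            return "skipped"
        return "traced"

class DummyFxCompatibleTest(ModelTesterMixin, unittest.TestCase):
    fx_compatible = True  # replaces the two per-class tuples

    def test_fx_opt_in(self):
        self.assertEqual(self._create_and_check_torch_fx_tracing(), "traced")

if __name__ == "__main__":
    unittest.main()
```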