"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "269e73b6011f87f4d5e6fea47f8fee11dfcdf2cc"
Commit 6ce1ee04 authored by LysandreJik

TorchScript testing with output_attentions and output_hidden_state

parent 7ed5bf70
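For context, a minimal sketch of what the two config flags toggled by this commit change, assuming BertModel as a stand-in for any model class in the library, and assuming a recent transformers release where outputs are ModelOutput objects (at the time of this commit they were plain tuples):

import torch
from transformers import BertConfig, BertModel

# Both flags set on the config, as the new test helpers do.
config = BertConfig(output_attentions=True, output_hidden_states=True)
model = BertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    outputs = model(input_ids)

# With the flags enabled, the outputs carry one attention map per layer
# and one hidden state per layer (plus the embeddings) in addition to
# the usual tensors -- extra outputs the traced module must reproduce.
print(len(outputs.attentions), len(outputs.hidden_states))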
@@ -31,13 +31,50 @@ def _config_zero_init(config):
         setattr(configs_no_init, key, 0.0)
     return configs_no_init
 
+def _create_and_check_torchscript_output_attentions(tester, model_classes, config, inputs_dict):
+    config.output_attentions = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+def _create_and_check_torchscript_output_hidden_state(tester, model_classes, config, inputs_dict):
+    config.output_hidden_states = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
 def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
     for model_class in model_classes:
         model = model_class(config=configs_no_init)
         model.eval()
         inputs = inputs_dict['input_ids']  # Let's keep only input_ids
-        traced_model = torch.jit.trace(model, inputs)
+
+        try:
+            torch.jit.trace(model, inputs)
+        except RuntimeError:
+            tester.parent.fail("Couldn't trace module.")
+
+        try:
+            traced_gpt2 = torch.jit.trace(model, inputs)
+            torch.jit.save(traced_gpt2, "traced_model.pt")
+        except RuntimeError:
+            tester.parent.fail("Couldn't save module.")
+
+        try:
+            loaded_model = torch.jit.load("traced_model.pt")
+            os.remove("traced_model.pt")
+        except ValueError:
+            tester.parent.fail("Couldn't load module.")
+
+        model.eval()
+        loaded_model.eval()
+
+        model_params = model.parameters()
+        loaded_model_params = loaded_model.parameters()
+
+        models_equal = True
+        for p1, p2 in zip(model_params, loaded_model_params):
+            if p1.data.ne(p2.data).sum() > 0:
+                models_equal = False
+
+        tester.parent.assertTrue(models_equal)
 
 def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
@@ -164,6 +201,8 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
 def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
     _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
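The trace/save/load round trip exercised by the test body can be reproduced in isolation. A minimal sketch, assuming a toy nn.Module in place of a transformers model; ToyModel and the file name traced_toy.pt are illustrative only, not part of the commit:

import os

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, input_ids):
        return self.linear(input_ids)

model = ToyModel()
model.eval()
inputs = torch.rand(2, 4)

# Trace, save, reload: the same three steps the try/except blocks guard.
traced = torch.jit.trace(model, inputs)
torch.jit.save(traced, "traced_toy.pt")
loaded = torch.jit.load("traced_toy.pt")
os.remove("traced_toy.pt")

# Parameter-level equality, mirroring the test's final check.
models_equal = all(
    p1.data.ne(p2.data).sum() == 0
    for p1, p2 in zip(model.parameters(), loaded.parameters())
)
assert models_equal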