Commit 7ed5bf70 authored by thomwolf's avatar thomwolf
Browse files

add tests

parent 70887795
......@@ -31,6 +31,14 @@ def _config_zero_init(config):
setattr(configs_no_init, key, 0.0)
return configs_no_init
def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in model_classes:
model = model_class(config=configs_no_init)
model.eval()
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
traced_model = torch.jit.trace(model, inputs)
def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
configs_no_init = _config_zero_init(config)
for model_class in model_classes:
......@@ -39,7 +47,7 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict)
tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
configs_no_init = _config_zero_init(config)
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in model_classes:
config.output_attentions = True
config.output_hidden_states = True
......@@ -155,6 +163,7 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
_create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment