"sgl-router/src/git@developer.sourcefind.cn:change/sglang.git" did not exist on "1a8706c8b94918915bcaa44ddbc9e29a0cfea3b2"
Commit 7ed5bf70 authored by thomwolf

add tests

parent 70887795
@@ -31,6 +31,14 @@ def _config_zero_init(config):
             setattr(configs_no_init, key, 0.0)
     return configs_no_init
 
+def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        model.eval()
+        inputs = inputs_dict['input_ids']  # Let's keep only input_ids
+        traced_model = torch.jit.trace(model, inputs)
+
 def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
     for model_class in model_classes:
@@ -39,7 +47,7 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
             tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
 
 def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
-    configs_no_init = _config_zero_init(config)
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = True
@@ -155,6 +163,7 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
 
 def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
     _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
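The new _create_and_check_torchscript helper only verifies that torch.jit.trace runs without raising for each model class. As a rough illustration of what tracing provides, here is a minimal, self-contained sketch using a hypothetical toy module (TinyModel is not one of the library's models, just a stand-in for this example): the traced module should reproduce eager-mode outputs and can be saved and reloaded without the original Python class definition.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for the transformer models exercised by the test above
    def __init__(self, vocab_size=100, hidden_size=8):
        super(TinyModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        return self.linear(self.embed(input_ids))

model = TinyModel()
model.eval()
input_ids = torch.randint(0, 100, (2, 5))  # (batch_size, sequence_length)

# Trace with example inputs, as the new test does for each model class
traced_model = torch.jit.trace(model, input_ids)

# The traced module should reproduce eager-mode outputs on the same inputs
with torch.no_grad():
    assert torch.allclose(model(input_ids), traced_model(input_ids))

# Traced modules can be serialized and reloaded without the Python class definition
torch.jit.save(traced_model, "traced_tiny_model.pt")
reloaded = torch.jit.load("traced_tiny_model.pt")

If the commons test were extended along these lines, an output-equality check like the allclose above would also flag models whose forward pass relies on Python control flow that tracing cannot capture.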