Unverified Commit ac99217e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix the CI (#4903)

* Fix CI
parent 0a375f5a
...@@ -67,6 +67,8 @@ class ModelTesterMixin: ...@@ -67,6 +67,8 @@ class ModelTesterMixin:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return { return {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim != 0
else v
for k, v in inputs_dict.items() for k, v in inputs_dict.items()
} }
return inputs_dict return inputs_dict
...@@ -157,7 +159,7 @@ class ModelTesterMixin: ...@@ -157,7 +159,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment