Unverified Commit ac99217e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix the CI (#4903)

* Fix CI
parent 0a375f5a
......@@ -67,6 +67,8 @@ class ModelTesterMixin:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim != 0
else v
for k, v in inputs_dict.items()
}
return inputs_dict
......@@ -157,7 +159,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment