Unverified Commit 7f65daa2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix reformer fp16 (#6237)

parent 7ea9b2db
...@@ -389,7 +389,7 @@ class ReformerModelTester: ...@@ -389,7 +389,7 @@ class ReformerModelTester:
model.to(torch_device) model.to(torch_device)
model.half() model.half()
model.eval() model.eval()
output = model(input_ids, attention_mask=input_mask)["last_input_state"] output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment