Commit b3d83d68 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fixup 9d060314

parent 75d5f98f
...@@ -139,6 +139,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -139,6 +139,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0] our_output = model(input_ids)[0]
if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0] their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape) print(our_output.shape, their_output.shape)
success = torch.allclose(our_output, their_output, atol=1e-3) success = torch.allclose(our_output, their_output, atol=1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment