"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ba6f6e44a89d201099142b2e4a72c0ebac74e960"
Commit b3d83d68 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fixup 9d060314

parent 75d5f98f
...@@ -139,7 +139,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -139,7 +139,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0] our_output = model(input_ids)[0]
their_output = roberta.model(input_ids)[0] if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape) print(our_output.shape, their_output.shape)
success = torch.allclose(our_output, their_output, atol=1e-3) success = torch.allclose(our_output, their_output, atol=1e-3)
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment