"vscode:/vscode.git/clone" did not exist on "c45d8ac55439decd059d697e21daf27e85ac3412"
Unverified Commit 32d20314 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[fix] slow fill_mask test failure (#5406)

parent 80aa4b8a
......@@ -827,6 +827,7 @@ class FillMaskPipeline(Pipeline):
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
logits = outputs[i, masked_index, :]
probs = logits.softmax(dim=0)
values, predictions = probs.topk(self.topk)
......
......@@ -31,12 +31,12 @@ TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translat
expected_fill_mask_result = [
[
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"},
],
[
{"sequence": "<s> The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
{"sequence": "<s> The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
{"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
],
]
SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment