"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7533d30acd975027e83a548e4c38e06fa335291b"
Unverified Commit 18ecd36f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix Bart Shift (#9135)

* correct mistake in order

* fix tensor copy

* clone tensor correctly
parent d018622d
......@@ -72,9 +72,10 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
# replace possible -100 values in labels by `pad_token_id`
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
prev_output_tokens[:, 0] = decoder_start_tokens
return prev_output_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment