"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "34a3c25a3068ab5cdbecb08ddf2866f1209fd2dd"
Unverified commit d018622d, authored by Patrick von Platen and committed by GitHub
Browse files

correct mistake in order (#9134)

parent 80bdb9c3
...@@ -67,14 +67,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): ...@@ -67,14 +67,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
Shift input ids one token to the right, and wrap the last non pad token (usually <eos>). Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
""" """
prev_output_tokens = input_ids.clone() prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens return prev_output_tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment