"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "35e06872560c243b09104482736d84edeecbfe04"
Unverified Commit 80bdb9c3, authored by Patrick von Platen, committed by GitHub
Browse files

fix bart loss masking (#9131)

parent 3caba8d3
@@ -70,6 +70,11 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
 index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
 prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
 prev_output_tokens[:, 1:] = input_ids[:, :-1]
+assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+# replace possible -100 values in labels by `pad_token_id`
+prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
 return prev_output_tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment