Unverified Commit 9fd11bf1 authored by Henry Dashwood's avatar Henry Dashwood Committed by GitHub
Browse files

replace torch.triu with onnx compatible code (#6929)

parent ed71c21d
......@@ -154,9 +154,10 @@ def _prepare_bart_decoder_inputs(
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
tmp = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len))
mask = torch.arange(tmp.size(-1))
tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), 0)
causal_mask = tmp.to(dtype=causal_mask_dtype, device=decoder_input_ids.device)
return decoder_input_ids, decoder_padding_mask, causal_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment