"docs/source/ja/index.md" did not exist on "4d367a3c8101bd72a711445d305005133ad15e73"
Unverified Commit 7d4cfa3b authored by David Fan's avatar David Fan Committed by GitHub
Browse files

Rewrite ProphetNet to adapt converting ONNX friendly (#11981)

* Rewrite

* [ONNX] rewrite
parent c0fe3c9a
......@@ -1687,7 +1687,9 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
batch_size, seq_length = hidden_states.shape[:2]
# get causal mask
causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf"))
causal_mask = torch.full(
(seq_length, seq_length), -float("inf"), dtype=hidden_states.dtype, device=hidden_states.device
)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
(batch_size,) + causal_mask.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment