Unverified Commit 7d4cfa3b authored by David Fan's avatar David Fan Committed by GitHub
Browse files

Rewrite ProphetNet to adapt converting ONNX friendly (#11981)

* Rewrite

* [ONNX] rewrite
parent c0fe3c9a
...@@ -1687,7 +1687,9 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1687,7 +1687,9 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
# get causal mask # get causal mask
causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf")) causal_mask = torch.full(
(seq_length, seq_length), -float("inf"), dtype=hidden_states.dtype, device=hidden_states.device
)
causal_mask = torch.triu(causal_mask, 1) causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand( extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
(batch_size,) + causal_mask.shape (batch_size,) + causal_mask.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment