Unverified commit 390cf16b, authored by guillaume-be, committed by GitHub
Browse files

Prophetnet optimization (#9453)

* Vectorized `ngram_attention_bias` calculation

* updated formatting with black

* Further optimization

* one (last) optimization
parent 28d74872
@@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     """
     This function computes the bias for the predict stream
     """
-    bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf")
+    left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+    right_block = left_block.detach().clone()
     # create bias
     for stream_idx in range(ngram):
-        for i in range(sequence_length):
-            bias[stream_idx, i, sequence_length + i] = 0
-            bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0
-    return bias
+        right_block[stream_idx].fill_diagonal_(0, wrap=False)
+        left_block[stream_idx].triu_(-stream_idx + 1)
+
+    left_block[:, :, 0] = 0
+    return torch.cat([left_block, right_block], dim=2)
 
 
 def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment