Unverified commit 5c009186, authored by Kiran R and committed by GitHub
Browse files

Added support for exporting T5 to ONNX with past_key_values (#10651)

parent 50f4539b
......@@ -423,6 +423,8 @@ class T5Attention(nn.Module):
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
int_seq_length = int(seq_length)
real_seq_length = seq_length
if past_key_value is not None:
......@@ -489,7 +491,7 @@ class T5Attention(nn.Module):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -seq_length:, :]
position_bias = position_bias[:, :, -int_seq_length:, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment