"tests/vscode:/vscode.git/clone" did not exist on "56b03c96b865a40811f4eb2942e71aaab4cd38c2"
Unverified Commit d23d2c27 authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

Represent query_length in a different way to solve jit issue (#25164)

Fix jit trace
parent 2a787201
......@@ -154,9 +154,7 @@ class MptAttention(nn.Module):
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
query_length = seq_length
if past_key_value is not None:
query_length += past_key_value[0].shape[2]
query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
if position_bias is not None:
if len(position_bias.shape) != 3:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment