Unverified Commit d23d2c27 authored by jiqing-feng, committed by GitHub

Represent query_length in a different way to solve jit issue (#25164)

Fix jit trace
parent 2a787201
@@ -154,9 +154,7 @@ class MptAttention(nn.Module):
         attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
 
-        query_length = seq_length
-        if past_key_value is not None:
-            query_length += past_key_value[0].shape[2]
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
 
         if position_bias is not None:
             if len(position_bias.shape) != 3:
...
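
For context, a minimal, hypothetical sketch of the pattern the patch adopts. Nothing below is from the transformers repository: the BiasSliceDemo module, its tensor shapes, and the slicing of the bias table are all made up for illustration. It only shows that computing query_length as a single conditional expression, rather than assigning it and then mutating it with "+=", works cleanly under torch.jit.trace, which is one plausible reading of why the change fixes the reported jit issue.

# Hypothetical demo (not the transformers code); mirrors the patched
# single-expression computation of query_length.
import torch
from torch import nn


class BiasSliceDemo(nn.Module):
    def forward(self, hidden_states, position_bias, past_key):
        seq_length = hidden_states.shape[1]
        # Patched form: one conditional expression, no assign-then-"+=".
        query_length = seq_length if past_key is None else seq_length + past_key.shape[2]
        # Use the length to slice a bias table; the exact slicing here is
        # invented, only loosely echoing how MptAttention uses position_bias.
        return position_bias[..., -query_length:]


demo = BiasSliceDemo()
hidden = torch.randn(1, 4, 8)       # (batch, seq_length, dim)
bias = torch.randn(1, 16, 16)       # hypothetical bias table
past_key = torch.randn(1, 2, 3, 8)  # (batch, heads, past_len, head_dim)

# torch.jit.trace records a single execution; the Python-level conditional
# is resolved at trace time, so the one-expression form keeps the traced
# graph simple and consistent.
traced = torch.jit.trace(demo, (hidden, bias, past_key))
print(traced(hidden, bias, past_key).shape)  # torch.Size([1, 16, 7])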