Unverified Commit be81269e authored by UnicornChan's avatar UnicornChan Committed by GitHub
Browse files

Merge pull request #71 from Azure-Tang/main

[fix] Fix qlen > chunk_size mask is none error
parents 022b8938 c55de02f
...@@ -195,11 +195,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ...@@ -195,11 +195,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
[:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
self.attn_mask[:, :, :, :cur_idx] = 0 self.attn_mask[:, :, :, :cur_idx] = 0
chunck_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
cur_output, _, _ = self.forward_chunck( cur_output, _, _ = self.forward_chunck(
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
chunck_mask, chunk_mask,
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value, past_key_value,
output_attentions, output_attentions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment