Unverified commit c51818c3, authored by Atream and committed by GitHub
Browse files

Merge pull request #902 from kvcache-ai/rollback-triton-prefill

rollback-triton-prefill
parents bda9cf15 3934b9df
...@@ -325,27 +325,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ...@@ -325,27 +325,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
# for bsz = 1 attn_output = flash_attn_func(
attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device) query_states,
b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device) key_states,
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device) value_states_padded,
softmax_scale=self.softmax_scale,
max_input_len = q_len causal=True,
context_attention_fwd(
q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
o=attn_output,
b_start_loc=b_start_loc,
b_seq_len=b_seq_len,
max_input_len=max_input_len,
is_causal=True
) )
if self.q_head_dim != self.v_head_dim: if self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, : self.v_head_dim] attn_output = attn_output[:, :, :, : self.v_head_dim]
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.v_head_dim bsz, q_len, self.num_heads * self.v_head_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment