Commit 9c3190d0 authored by zhuwenwen's avatar zhuwenwen
Browse files

set vdim=128

parent ec2e17d8
......@@ -533,39 +533,37 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
attn_output = flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
else:
attn_output = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
attn_output = attn_output\
.reshape(-1, self.num_heads * v.shape[-1])
# if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
# attn_output = flash_attn_varlen_func(
# q=q,
# k=k,
# v=v_padded,
# cu_seqlens_q=seq_start_loc,
# cu_seqlens_k=seq_start_loc,
# max_seqlen_q=max_prefill_seq_len,
# max_seqlen_k=max_prefill_seq_len,
# softmax_scale=self.scale,
# causal=True,
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
# else:
attn_output = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
attn_output = attn_output\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_output)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment