Commit b01efa0b authored by zhuwenwen
Browse files

remove unused mla utils.py

parent 4a19cdf5
...@@ -1297,7 +1297,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1297,7 +1297,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
output = self.flash_attn_varlen_func( output = self.flash_attn_varlen_func(
q=q, q=q,
k=k, k=k,
v=v_padded, v=v,
cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len, max_seqlen_q=prefill_metadata.max_prefill_seq_len,
...@@ -1323,8 +1323,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1323,8 +1323,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
) )
# slice by `:v.shape[-1]` in order to remove v headdim padding # slice by `:v.shape[-1]` in order to remove v headdim padding
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
output = output\ output = output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1]) .reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0] return self.o_proj(output)[0]
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment