Commit 256749c9 authored by liuchy5
Browse files

feat:flash_mla,q去掉pad

parent adbd3d7b
...@@ -924,14 +924,14 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -924,14 +924,14 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
padded_num_heads = self.fp8_decode_padded_heads padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128) # Pad query if needed (kernel only supports h_q = 64 or 128)
if actual_num_heads < padded_num_heads: #if actual_num_heads < padded_num_heads:
logger.warning_once( # logger.warning_once(
f"Padding num_heads from {actual_num_heads} to " # f"Padding num_heads from {actual_num_heads} to "
f"{padded_num_heads} for FP8 sparse decode kernel" # f"{padded_num_heads} for FP8 sparse decode kernel"
) # )
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3))) # q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
q_padded[:, :, :actual_num_heads, :] = q # q_padded[:, :, :actual_num_heads, :] = q
q = q_padded # q = q_padded
out, lse = flash_mla_with_kvcache( out, lse = flash_mla_with_kvcache(
q=q, q=q,
...@@ -946,8 +946,8 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -946,8 +946,8 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
) )
# Slice output back to actual head count if we padded # Slice output back to actual head count if we padded
if actual_num_heads < padded_num_heads: #if actual_num_heads < padded_num_heads:
out = out[:, :, :actual_num_heads, :] # out = out[:, :, :actual_num_heads, :]
return out, lse return out, lse
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment