Commit cc2dca96 authored by SAC_fanth

Fix bug where MTP and VLLM_USE_TRITON_CAT could not be enabled together

parent f872f0ad
@@ -167,7 +167,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert attn_metadata.decode is not None
         if envs.VLLM_USE_TRITON_CAT:
-            if q_nope.shape[0] <= 1024:
+            if q_nope.shape[0] < 1024:
                 q = concat_helper_decode(q_nope, q_pe, dim=2)\
                     .unsqueeze(1)
             else:
...
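The hunk above changes one thing: the comparison goes from `<=` to `<`, so a decode batch with exactly 1024 rows in `q_nope` no longer takes the `concat_helper_decode` path and falls through to the `else` branch instead. The commit does not say why that size fails when MTP is enabled, so the sketch below only illustrates the boundary shift. The names `LIMIT`, `takes_helper_path_old` and `takes_helper_path_new` are hypothetical stand-ins, not part of vLLM.

```python
# Hypothetical stand-ins for the two versions of the branch condition.
# They mirror only the comparison on q_nope.shape[0] shown in the diff.
LIMIT = 1024


def takes_helper_path_old(num_rows: int) -> bool:
    """Condition before the commit: q_nope.shape[0] <= 1024."""
    return num_rows <= LIMIT


def takes_helper_path_new(num_rows: int) -> bool:
    """Condition after the commit: q_nope.shape[0] < 1024."""
    return num_rows < LIMIT


# Collect every row count for which the two versions pick different branches.
changed = [
    n for n in range(1, 2049)
    if takes_helper_path_old(n) != takes_helper_path_new(n)
]
print(changed)  # [1024]
```

Every other batch size is routed exactly as before; only the 1024-row case moves to the `else` branch.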