Commit c6a45c08 authored by yangql's avatar yangql
Browse files

关闭sparse_mla的num_head到64/128的pad,以及添加控制fp8_use_mixed_batch模式的环境变量控制,FP8_USE_MI...

关闭sparse_mla的num_head到64/128的pad,以及添加控制fp8_use_mixed_batch模式的环境变量控制,FP8_USE_MIXED_BATCH,默认为false,为分离模式
parent 656944ac
......@@ -297,6 +297,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False
FP8_USE_MIXED_BATCH: bool = False
VLLM_W8A8_BACKEND: int = 3
VLLM_USE_PP_BALANCE = True
VLLM_MOE_ROUTER_CAPTURE: bool = False
......@@ -1825,7 +1826,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use fused cat and mla
"VLLM_USE_CAT_MLA":
lambda: (os.getenv('VLLM_USE_CAT_MLA', 'False').lower() in
("true", "1")),
("true", "1")),
# vllm will use fused cat and mla
"FP8_USE_MIXED_BATCH":
lambda: (os.getenv('FP8_USE_MIXED_BATCH', 'False').lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states,not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm import envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
......@@ -668,7 +669,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and envs.FP8_USE_MIXED_BATCH
if self.use_fp8_kv_cache:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
......@@ -924,14 +925,14 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128)
if actual_num_heads < padded_num_heads:
logger.warning_once(
f"Padding num_heads from {actual_num_heads} to "
f"{padded_num_heads} for FP8 sparse decode kernel"
)
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
q_padded[:, :, :actual_num_heads, :] = q
q = q_padded
# if actual_num_heads < padded_num_heads:
# logger.warning_once(
# f"Padding num_heads from {actual_num_heads} to "
# f"{padded_num_heads} for FP8 sparse decode kernel"
# )
# q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
# q_padded[:, :, :actual_num_heads, :] = q
# q = q_padded
out, lse = flash_mla_with_kvcache(
q=q,
......@@ -964,15 +965,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded
# if self.num_heads % self.prefill_padding != 0:
# assert self.prefill_padding % self.num_heads == 0
# logger.warning_once(
# f"Padding num_heads from {self.num_heads} to "
# f"{self.prefill_padding} for BF16 sparse prefill kernel"
# )
# q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
# q_padded[:, : self.num_heads, :] = q
# q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment