Commit 855cb148 authored by 王敏's avatar 王敏
Browse files

merge dev分支代码

parents 9135afe4 fe2e2705
......@@ -1056,9 +1056,21 @@ class CustomAllreduce {
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
{ \
void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \
case ngpus: { \
......@@ -1066,7 +1078,7 @@ class CustomAllreduce {
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
if ((world_size_ <= 4 && bytes < 128 * 8192) || \
(world_size_ <= 8 && bytes < 8 * 8192)) { \
(world_size_ <= 8 && bytes < 8 * 8192)) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
KL(ngpus, cross_device_reduce_2stage_pcie); \
......
......@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt3.' + sha[:7]
version = 'das.opt4.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt3'
version = 'das.opt4'
# dtk version
......
......@@ -204,6 +204,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -270,7 +272,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
......@@ -511,6 +513,8 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -542,6 +546,8 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori,
key_normed=key_normed,
positions=positions,
......@@ -560,6 +566,8 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......
......@@ -277,6 +277,60 @@ def flash_mla_with_kvcache_fp8(
)
return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
#
# TODO: Add fake functions
#
......
......@@ -271,10 +271,7 @@ class CustomAllreduce:
if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True)
else:
if not self.fully_connected:
return self.all_reduce(input, registered=False)
else:
return self.all_reduce(input, registered=True)
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
......
......@@ -119,10 +119,10 @@ class P2pNcclEngine:
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
if self.enable_asymmetric_p2p == True:
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset
......@@ -742,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
}
logger.info(f"""_send_sync_new:{data}""")
# logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data))
response = sock.recv()
......@@ -981,4 +981,4 @@ class P2pNcclEngine:
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
\ No newline at end of file
f"Request id {request_id} does not contain hostname and port")
......@@ -198,8 +198,8 @@ if TYPE_CHECKING:
VLLM_PP_DEBUG: bool = False
VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
......@@ -208,6 +208,13 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_W8A8_BACKEND: int = 3
VLLM_MOE_ROUTER_CAPTURE: bool = False
VLLM_MOE_ROUTER_CAPTURE_DIR: str = "/tmp"
VLLM_MOE_ROUTER_CAPTURE_RANK: int = -1
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
def get_default_cache_root():
return os.getenv(
......@@ -1062,7 +1069,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "1"))),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
......@@ -1070,7 +1077,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
......@@ -1097,7 +1104,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))),
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
......@@ -1223,7 +1230,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
# vLLM will sync to avoid pp vmfault
......@@ -1295,14 +1302,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in
("true", "1")),
# vllm will use fused rmsnorm + contiguous + rope(for dpsk-v3) + concat_and_cache_mla + q quant, control bmm + cat +mla (fp8)
"VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA":
lambda: (os.getenv('VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA', 'False').lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......@@ -1341,6 +1349,34 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE":
lambda: (os.getenv("VLLM_MOE_ROUTER_CAPTURE", "0").lower() in ("true", "1")),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR":
lambda: os.environ.get(
"VLLM_MOE_ROUTER_CAPTURE_DIR",
"/tmp",
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_RANK", "-1")),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS", "0")),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT", "-1")),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
}
# --8<-- [end:env-vars-definition]
......@@ -1405,6 +1441,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_W8A8_BACKEND",
]
for key in environment_variables_to_hash:
if key in environment_variables:
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
......@@ -1225,14 +1225,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM:
if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
from lightop import op as op
op.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
True,
renormalize,
)
else:
ops.topk_softmax(
......@@ -1681,93 +1681,88 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is not None:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)):
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
# No fallback packing: require pre-packed weights when Marlin W16A16
# MoE is enabled. If weights are still in the original layout, fail
# fast to avoid packing-induced peak memory and unpredictable
# warmup/profiling behavior.
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0)
and w2.size(1) == K):
twoN = w1.size(1)
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0)
and (twoN % 32 == 0)):
raise RuntimeError(
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights "
"during the post-load hook or disable "
"VLLM_USE_MARLIN_W16A16_MOE."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe:
......@@ -1791,7 +1786,7 @@ def fused_experts_impl(
device=hidden_states.device,
dtype=hidden_states.dtype)
if use_int8_w8a8 is True:
if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
......@@ -1801,8 +1796,8 @@ def fused_experts_impl(
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os
import importlib
......@@ -76,6 +77,66 @@ else:
logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
E: int,
N: int,
K: int,
top_k: int,
dtype: torch.dtype,
) -> bool:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
return False
if dtype not in (torch.float16, torch.bfloat16):
return False
if K % 32 != 0 or N % 16 != 0:
return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False
if not envs.VLLM_USE_LIGHTOP:
return False
try:
from lightop import get_moe_cuda_marlin_config_w16a16
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
twoN,
K,
K,
N,
top_k,
arch_name,
arch_cu,
dtype,
)
if not status:
return False
return True
except Exception:
return False
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
......@@ -84,6 +145,7 @@ logger = init_logger(__name__)
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
......@@ -407,12 +469,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
if (getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
......@@ -421,12 +484,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
......@@ -992,9 +1049,25 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
))
self.use_nn_moe = int(os.environ.get("MOE_NN", 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else:
self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = {
"num_experts": self.local_num_experts,
......
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from __future__ import annotations
import atexit
import inspect
import os
import socket
import threading
import time
from dataclasses import dataclass
from typing import Optional
import torch
import vllm.envs as envs
_DEFAULT_SKIP_STACK_FUNCS = ("profile_run", "_dummy_run",
"determine_available_memory")
@dataclass(frozen=True)
class RouterCaptureConfig:
enabled: bool = False
out_dir: str = "/tmp"
skip_profile: bool = True
skip_stack_funcs: tuple[str, ...] = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = 0
max_layers: int = 0
num_tokens_gt: Optional[int] = None
num_tokens_lt: Optional[int] = None
@staticmethod
def from_env() -> "RouterCaptureConfig":
enabled = envs.VLLM_MOE_ROUTER_CAPTURE
out_dir = envs.VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile = True
skip_stack_funcs = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = None
if envs.VLLM_MOE_ROUTER_CAPTURE_RANK >= 0:
only_rank = envs.VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers = envs.VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT >= 0
else None)
num_tokens_lt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT > 0
else None)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if num_tokens_gt_opt is None and num_tokens_lt_opt is None:
enabled = False
if (num_tokens_gt_opt is not None and num_tokens_lt_opt is not None
and num_tokens_gt_opt >= num_tokens_lt_opt):
enabled = False
return RouterCaptureConfig(enabled=enabled,
out_dir=out_dir,
skip_profile=skip_profile,
skip_stack_funcs=skip_stack_funcs,
only_rank=only_rank,
max_layers=max_layers,
num_tokens_gt=num_tokens_gt_opt,
num_tokens_lt=num_tokens_lt_opt)
def _in_profile_run(skip_stack_funcs: tuple[str, ...]) -> bool:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if not skip_stack_funcs:
return False
frame = inspect.currentframe()
try:
while frame is not None:
name = frame.f_code.co_name
if name in skip_stack_funcs:
return True
frame = frame.f_back
finally:
# Avoid reference cycles.
del frame
return False
class _RouterCapture:
def __init__(self, cfg: RouterCaptureConfig) -> None:
self.cfg = cfg
# Bucket captures by token count.
self._layers_by_num_tokens: dict[int, dict[str, dict[str, object]]] = {}
self._layer_names: set[str] = set()
self._completed_num_tokens: set[int] = set()
self._lock = threading.Lock()
self._flush_counter = 0
self._pid = os.getpid()
self._host = socket.gethostname()
self._start_time = time.time()
os.makedirs(cfg.out_dir, exist_ok=True)
atexit.register(self.flush)
def _bucket_for_num_tokens(self, num_tokens: int) -> Optional[int]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if self.cfg.num_tokens_gt is None and self.cfg.num_tokens_lt is None:
return None
if self.cfg.num_tokens_gt is not None:
if int(num_tokens) <= int(self.cfg.num_tokens_gt):
return None
if self.cfg.num_tokens_lt is not None:
if int(num_tokens) >= int(self.cfg.num_tokens_lt):
return None
bucket_num_tokens = int(num_tokens)
if bucket_num_tokens != 0 and bucket_num_tokens in self._completed_num_tokens:
return None
return bucket_num_tokens
def _snapshot_layers_by_num_tokens(
self,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
) -> dict[int, dict[str, dict[str, object]]]:
snapshot: dict[int, dict[str, dict[str, object]]] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_snapshot: dict[str, dict[str, object]] = {}
for layer_name, state in bucket.items():
chunks = state.get("router_logits_chunks", [])
bucket_snapshot[layer_name] = {
"num_experts": int(state.get("num_experts", 0)),
"num_tokens": int(state.get("num_tokens", 0)),
"router_logits_chunks": list(chunks),
}
snapshot[int(num_tokens)] = bucket_snapshot
return snapshot
@torch.no_grad()
def record(self, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
if self.cfg.skip_profile and _in_profile_run(self.cfg.skip_stack_funcs):
return
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return
if router_logits.dim() != 2:
return
num_tokens, num_experts = router_logits.shape
if num_tokens == 0 or num_experts == 0:
return
bucket_num_tokens = self._bucket_for_num_tokens(int(num_tokens))
if bucket_num_tokens is None:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if layer_name not in self._layer_names:
if self.cfg.max_layers != 0 and len(self._layer_names) >= self.cfg.max_layers:
return
self._layer_names.add(layer_name)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu = router_logits.detach()
if router_logits_cpu.is_cuda:
router_logits_cpu = router_logits_cpu.to(device="cpu",
dtype=torch.float16)
else:
router_logits_cpu = router_logits_cpu.to(dtype=torch.float16)
bucket_snapshot: Optional[dict[str, dict[str, object]]] = None
should_flush = False
with self._lock:
bucket = self._layers_by_num_tokens.setdefault(bucket_num_tokens, {})
if layer_name in bucket:
return
bucket[layer_name] = {
"num_experts": int(num_experts),
"num_tokens": int(num_tokens),
"router_logits_chunks": [router_logits_cpu],
}
if self.cfg.max_layers != 0 and len(bucket) >= int(self.cfg.max_layers):
should_flush = True
bucket_snapshot = self._snapshot_layers_by_num_tokens(
{int(bucket_num_tokens): bucket})[int(bucket_num_tokens)]
self._completed_num_tokens.add(int(bucket_num_tokens))
self._layers_by_num_tokens.pop(int(bucket_num_tokens), None)
if should_flush and bucket_snapshot is not None:
self._flush_payload(
layers_by_num_tokens={int(bucket_num_tokens): bucket_snapshot},
file_tag=f"nt{int(bucket_num_tokens)}",
)
def _flush_payload(
self,
*,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
file_tag: Optional[str] = None,
) -> Optional[str]:
if not self.cfg.enabled:
return None
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return None
rank = _get_rank()
now = time.time()
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
ts_us = int(now * 1_000_000)
with self._lock:
flush_idx = self._flush_counter
self._flush_counter += 1
rank_str = f"rank{rank}" if rank is not None else "rankNA"
tag = f"{file_tag}_" if file_tag else ""
out_path = os.path.join(
self.cfg.out_dir,
f"moe_router_stats_{tag}{ts_us}_{self._host}_{rank_str}_pid{self._pid}_flush{flush_idx}.pt",
)
layers_by_num_tokens_out: dict[str, object] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_out: dict[str, object] = {}
for layer_name, state in bucket.items():
bucket_out[layer_name] = {
"num_experts": int(state["num_experts"]),
"num_tokens": int(state["num_tokens"]),
"router_logits_chunks":
state["router_logits_chunks"], # type: ignore[typeddict-item]
}
layers_by_num_tokens_out[str(int(num_tokens))] = bucket_out
payload: dict[str, object] = {
"meta": {
"timestamp": ts,
"timestamp_us": ts_us,
"flush_index": int(flush_idx),
"host": self._host,
"pid": self._pid,
"rank": rank,
"wall_time_s": float(now - self._start_time),
},
"layers_by_num_tokens": layers_by_num_tokens_out,
}
# Backward-compatible convenience field when there is a single bucket.
if len(layers_by_num_tokens) == 1:
(only_bucket_key, ) = layers_by_num_tokens.keys()
payload["layers"] = layers_by_num_tokens_out[str(int(only_bucket_key))]
try:
torch.save(payload, out_path)
except Exception:
return None
return out_path
def flush(self) -> Optional[str]:
with self._lock:
if not self._layers_by_num_tokens:
return None
snapshot = self._snapshot_layers_by_num_tokens(self._layers_by_num_tokens)
return self._flush_payload(layers_by_num_tokens=snapshot)
def reset(self) -> None:
with self._lock:
self._layers_by_num_tokens.clear()
self._layer_names.clear()
self._completed_num_tokens.clear()
self._start_time = time.time()
_CAPTURE: Optional[_RouterCapture] = None
_CAPTURE_DISABLED: bool = False
def _disable_global_capture() -> None:
global _CAPTURE, _CAPTURE_DISABLED
_CAPTURE = None
_CAPTURE_DISABLED = True
def _get_rank() -> Optional[int]:
if torch.distributed.is_available() and torch.distributed.is_initialized():
try:
return torch.distributed.get_rank()
except Exception:
return None
return None
def _get_capture() -> Optional[_RouterCapture]:
global _CAPTURE, _CAPTURE_DISABLED
if _CAPTURE_DISABLED:
return None
if _CAPTURE is not None:
return _CAPTURE
cfg = RouterCaptureConfig.from_env()
if not cfg.enabled:
_disable_global_capture()
return None
if cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != cfg.only_rank:
_disable_global_capture()
return None
_CAPTURE = _RouterCapture(cfg)
return _CAPTURE
@torch.no_grad()
def maybe_record_router_logits(*, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
capture = _get_capture()
if capture is None:
return
capture.record(layer_name=layer_name, router_logits=router_logits, top_k=top_k)
def maybe_flush_router_capture(*, reset: bool = False) -> Optional[str]:
"""Flush capture buffers to disk without exiting the process."""
capture = _get_capture()
if capture is None:
return None
out_path = capture.flush()
if out_path is not None and reset:
capture.reset()
return out_path
......@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric
@classmethod
......
......@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
weight = self._maybe_pad_weight(weight)
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
......
......@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
......
......@@ -6,6 +6,8 @@ import functools
import json
import os
from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch
......@@ -19,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
try:
from lmslim.layers.gemm.fp8_utils import per_token_group_quant_fp8,w8a8_block_fp8_matmul
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
logger = init_logger(__name__)
......@@ -83,7 +89,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool
use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[
torch.Tensor,
torch.Tensor,
......@@ -96,6 +102,9 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul
......@@ -127,7 +136,11 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm(output_dtype, weight):
......@@ -166,9 +179,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else:
use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported)
use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
......@@ -197,7 +213,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
......@@ -240,333 +260,9 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
tl.int64)
y_s_ptr += y_s_ptr_offset
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block FP8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
"Using default W8A8 Block FP8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s",
config_file_path,
)
return None
def w8a8_block_fp8_matmul(
def hipblaslt_w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
......@@ -574,80 +270,19 @@ def w8a8_block_fp8_matmul(
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# Get the optimal config if there is one
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
m, k = A.shape
_, n = B.shape
enum_block_size = BlockSize.block_128x128
if block_size[0] == 64:
enum_block_size = BlockSize.block_64x64
elif block_size[0] == 128:
enum_block_size = BlockSize.block_128x128
else:
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
print(f"[WARN] Unsupported block_size: {block_size}. Falling back to BlockSize.block_128x128")
_, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs,
m, n, k, 'NN', output_dtype,
enum_block_size, None)
return d
return C
......@@ -11,7 +11,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
......
......@@ -232,6 +232,11 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
#针对使用dtype为fp16的情况的量化默认关闭"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
if model_config.quantization in {"awq", "awq_marlin", "moe_wna16"}:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '0'
if not envs.VLLM_USE_NN:
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......@@ -255,8 +260,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1'
......@@ -278,6 +287,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......@@ -287,7 +298,7 @@ def get_model_architecture(
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
......@@ -298,8 +309,12 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
# if not envs.is_set("VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA"):
# os.environ['VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1'
......@@ -321,6 +336,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
......@@ -392,9 +392,12 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......@@ -436,9 +439,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else:
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
i_q=i_q, i_s=i_s)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......@@ -839,6 +845,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......@@ -887,6 +895,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......@@ -946,6 +956,8 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
query_nope=q[..., :self.qk_nope_head_dim],
num_local_heads=self.num_local_heads,
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment