"vscode:/vscode.git/clone" did not exist on "bb00f66e19acdf6cb614683ab74f777ed3932eee"
Commit 7624bd05 authored by zhuwenwen's avatar zhuwenwen
Browse files

[qwen3-235b] MoE(TN&NN) configs for nmz TP=8

[qwen3-480b] MoE(TN) configs for nmz TP=8
[opt] 优化deepep相关代码
[fix] 修复deepseek moe模型的awq量化推理bug和精度问题, 修复awq模型的VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD设置位置, update_state,优化性能,去除冗余操作
pcie 解决custom cudagraph模式需要拷贝的问题,需要配合dtk进行使用
[feat] Switch default w8a8 gemm impl to blaslt. Support w8a8-fp8 GEMM backend.MoE 路由抓取:新增 router_capture 工具链与 envs 统一配置
[envs] set VLLM_CUSTOM_CACHE=1、VLLM_USE_FUSED_RMS_ROPE=1、VLLM_USE_FUSED_FILL_RMS_CAT=1、VLLM_USE_FLASH_ATTN_FP8=1、VLLM_USE_FLASH_MLA_FP8=1、update VLLM_USE_TOPK_RENORM
parent ad1d74cf
...@@ -1056,9 +1056,22 @@ class CustomAllreduce { ...@@ -1056,9 +1056,22 @@ class CustomAllreduce {
size /= d; size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P); auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads); int blocks = std::min(block_limit, (size + threads - 1) / threads);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \ #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ { \
rank_, size, dev_curr_hdp_reg, world_size_) ; void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
......
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt3.' + sha[:7] version = 'das.opt4.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt3' version = 'das.opt4'
# dtk version # dtk version
......
...@@ -4822,13 +4822,9 @@ class VllmConfig: ...@@ -4822,13 +4822,9 @@ class VllmConfig:
if ep_sp or enable_dp_attention: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
else: else:
if ep_sp or enable_dp_attention: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -103,7 +103,7 @@ class DeviceCommunicatorBase: ...@@ -103,7 +103,7 @@ class DeviceCommunicatorBase:
# as long as we use data parallel (coupled data parallel # as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together), # where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel. # we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1 and not config.parallel_config.enable_dp_attention
self.use_all2all = "ep" in unique_name and use_ep self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None self.all2all_manager: Optional[All2AllManagerBase] = None
......
...@@ -271,10 +271,7 @@ class CustomAllreduce: ...@@ -271,10 +271,7 @@ class CustomAllreduce:
if envs.VLLM_CUSTOM_CACHE: if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True) return self.all_reduce(input, registered=True)
else: else:
if not self.fully_connected:
return self.all_reduce(input, registered=False) return self.all_reduce(input, registered=False)
else:
return self.all_reduce(input, registered=True)
else: else:
# If warm up, mimic the allocation pattern since custom # If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place. # allreduce is out-of-place.
......
...@@ -202,11 +202,19 @@ if TYPE_CHECKING: ...@@ -202,11 +202,19 @@ if TYPE_CHECKING:
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_MOE_ROUTER_CAPTURE: bool = False
VLLM_MOE_ROUTER_CAPTURE_DIR: str = "/tmp"
VLLM_MOE_ROUTER_CAPTURE_RANK: int = -1
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_W8A8_BACKEND: int = 3
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False VLLM_USE_FUSED_DTBMM: bool = False
...@@ -1064,7 +1072,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1064,7 +1072,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH ATTN fp8 attention optimizations. # If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8": "VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "1"))),
# If set, vLLM will use FLASH MLA attention optimizations. # If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
...@@ -1072,7 +1080,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1072,7 +1080,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH MLA fp8 attention optimizations. # If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8": "VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP": "VLLM_USE_OPT_OP":
...@@ -1099,7 +1107,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1099,7 +1107,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE": "VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))), lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
...@@ -1299,7 +1307,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1299,7 +1307,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use fused RMS + RoPE kernel # vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts # vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
...@@ -1316,11 +1324,38 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1316,11 +1324,38 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER", lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")), "0").lower() in ("true", "1")),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE":
lambda: (os.getenv("VLLM_MOE_ROUTER_CAPTURE", "0").lower() in ("true", "1")),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR":
lambda: os.environ.get(
"VLLM_MOE_ROUTER_CAPTURE_DIR",
"/tmp",
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_RANK", "-1")),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS", "0")),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT", "-1")),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# vLLM will use deepgemm kernel for deepep ht mode # vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
# vLLM will use deepep int8 dispatch
"VLLM_ENABLE_DEEPEP_INT8_DISPATCH":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_INT8_DISPATCH', '1').lower() in
("true", "1")),
# Only quantized DeepSeek models supported. # Only quantized DeepSeek models supported.
# Unquantized versions are not supported. # Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM": "VLLM_USE_FUSED_QA_KVA_GEMM":
...@@ -1340,6 +1375,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1340,6 +1375,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
), ),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# shared experts fusion # shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 1 enable shared experts fusion # VLLM_ENABLE_SHARED_EXPERTS_FUSION = 1 enable shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 0 disable shared experts fusion # VLLM_ENABLE_SHARED_EXPERTS_FUSION = 0 disable shared experts fusion
...@@ -1448,6 +1489,7 @@ def compute_hash() -> str: ...@@ -1448,6 +1489,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE", "VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE", "VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE", "VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_W8A8_BACKEND",
] ]
for key in environment_variables_to_hash: for key in environment_variables_to_hash:
if key in environment_variables: if key in environment_variables:
......
...@@ -136,8 +136,8 @@ def set_forward_context( ...@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel use_navie_all2all = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1
if use_navie_ep and dp_size > 1 and ( if use_navie_all2all and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None): attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
...@@ -1226,14 +1226,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1226,14 +1226,14 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM: if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
from lightop import op as op from lightop import op as op
op.topk_softmax( op.topk_softmax(
topk_weights, topk_weights,
topk_indices, topk_indices,
token_expert_indices, token_expert_indices,
gating_output, gating_output,
True, renormalize,
) )
else: else:
ops.topk_softmax( ops.topk_softmax(
......
...@@ -192,7 +192,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -192,7 +192,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
ll_prepare_finalize = DeepEPLLPrepareAndFinalize( ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle, ll_handle,
...@@ -249,7 +249,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -249,7 +249,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now. # profiling. Turning it off for now.
......
"""
Utilities for capturing MoE router distributions from real workloads.
This is intentionally lightweight and gated behind env vars so it has zero
runtime impact unless explicitly enabled.
Env vars (defaults from vllm.envs):
- VLLM_MOE_ROUTER_CAPTURE=0/1: enable capture (default: 0).
- VLLM_MOE_ROUTER_CAPTURE_DIR=/path: output directory for per-process dumps
(default: /tmp).
- VLLM_MOE_ROUTER_CAPTURE_RANK=N: only capture on the given torch.distributed
rank (default: -1; set to -1 to capture all ranks).
- VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS=N: max number of layers to record per
process (default: 0; 0 = unlimited).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT=A: only record calls where router_logits
has num_tokens > A (default: -1; <0 = disabled).
- VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT=B: only record calls where router_logits
has num_tokens < B (default: -1; 0 = disabled).
Output format:
- A single `.pt` per captured num_tokens (and per rank if torch.distributed is
initialized).
- Payload includes `layers_by_num_tokens: dict[str, dict[layer_name, layer_state]]`.
- A convenience `layers` field is also included (same as
`layers_by_num_tokens[str(num_tokens)]`) for easy loading.
- For each captured MoE layer, stores a list of 2D tensors
`router_logits_chunks: list[Tensor[num_tokens_i, num_experts]]` on CPU,
typically in fp16 for space efficiency.
"""
from __future__ import annotations
import atexit
import inspect
import os
import socket
import threading
import time
from dataclasses import dataclass
from typing import Optional
import torch
import vllm.envs as envs
_DEFAULT_SKIP_STACK_FUNCS = ("profile_run", "_dummy_run",
"determine_available_memory")
@dataclass(frozen=True)
class RouterCaptureConfig:
enabled: bool = False
out_dir: str = "/tmp"
skip_profile: bool = True
skip_stack_funcs: tuple[str, ...] = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = 0
max_layers: int = 0
num_tokens_gt: Optional[int] = None
num_tokens_lt: Optional[int] = None
@staticmethod
def from_env() -> "RouterCaptureConfig":
enabled = envs.VLLM_MOE_ROUTER_CAPTURE
out_dir = envs.VLLM_MOE_ROUTER_CAPTURE_DIR
skip_profile = True
skip_stack_funcs = _DEFAULT_SKIP_STACK_FUNCS
only_rank: Optional[int] = None
if envs.VLLM_MOE_ROUTER_CAPTURE_RANK >= 0:
only_rank = envs.VLLM_MOE_ROUTER_CAPTURE_RANK
max_layers = envs.VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS
num_tokens_gt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT >= 0
else None)
num_tokens_lt_opt = (envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT
if envs.VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT > 0
else None)
# Per-size mode requires an explicit token-count filter to avoid
# unbounded captures by default.
if num_tokens_gt_opt is None and num_tokens_lt_opt is None:
enabled = False
if (num_tokens_gt_opt is not None and num_tokens_lt_opt is not None
and num_tokens_gt_opt >= num_tokens_lt_opt):
enabled = False
return RouterCaptureConfig(enabled=enabled,
out_dir=out_dir,
skip_profile=skip_profile,
skip_stack_funcs=skip_stack_funcs,
only_rank=only_rank,
max_layers=max_layers,
num_tokens_gt=num_tokens_gt_opt,
num_tokens_lt=num_tokens_lt_opt)
def _in_profile_run(skip_stack_funcs: tuple[str, ...]) -> bool:
"""
Best-effort detection for vLLM startup profiling/warmup runs.
Startup warmups often execute MoE kernels with synthetic shapes. When
enabled, skip captures from these stacks so the first capture comes from a
real request.
"""
if not skip_stack_funcs:
return False
frame = inspect.currentframe()
try:
while frame is not None:
name = frame.f_code.co_name
if name in skip_stack_funcs:
return True
frame = frame.f_back
finally:
# Avoid reference cycles.
del frame
return False
class _RouterCapture:
def __init__(self, cfg: RouterCaptureConfig) -> None:
self.cfg = cfg
# Bucket captures by token count.
self._layers_by_num_tokens: dict[int, dict[str, dict[str, object]]] = {}
self._layer_names: set[str] = set()
self._completed_num_tokens: set[int] = set()
self._lock = threading.Lock()
self._flush_counter = 0
self._pid = os.getpid()
self._host = socket.gethostname()
self._start_time = time.time()
os.makedirs(cfg.out_dir, exist_ok=True)
atexit.register(self.flush)
def _bucket_for_num_tokens(self, num_tokens: int) -> Optional[int]:
"""Return the per-size bucket key for this record call, or None if filtered."""
if self.cfg.num_tokens_gt is None and self.cfg.num_tokens_lt is None:
return None
if self.cfg.num_tokens_gt is not None:
if int(num_tokens) <= int(self.cfg.num_tokens_gt):
return None
if self.cfg.num_tokens_lt is not None:
if int(num_tokens) >= int(self.cfg.num_tokens_lt):
return None
bucket_num_tokens = int(num_tokens)
if bucket_num_tokens != 0 and bucket_num_tokens in self._completed_num_tokens:
return None
return bucket_num_tokens
def _snapshot_layers_by_num_tokens(
self,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
) -> dict[int, dict[str, dict[str, object]]]:
snapshot: dict[int, dict[str, dict[str, object]]] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_snapshot: dict[str, dict[str, object]] = {}
for layer_name, state in bucket.items():
chunks = state.get("router_logits_chunks", [])
bucket_snapshot[layer_name] = {
"num_experts": int(state.get("num_experts", 0)),
"num_tokens": int(state.get("num_tokens", 0)),
"router_logits_chunks": list(chunks),
}
snapshot[int(num_tokens)] = bucket_snapshot
return snapshot
@torch.no_grad()
def record(self, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
if self.cfg.skip_profile and _in_profile_run(self.cfg.skip_stack_funcs):
return
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return
if router_logits.dim() != 2:
return
num_tokens, num_experts = router_logits.shape
if num_tokens == 0 or num_experts == 0:
return
bucket_num_tokens = self._bucket_for_num_tokens(int(num_tokens))
if bucket_num_tokens is None:
return
# Limit the number of recorded layers to avoid unbounded dumps.
if layer_name not in self._layer_names:
if self.cfg.max_layers != 0 and len(self._layer_names) >= self.cfg.max_layers:
return
self._layer_names.add(layer_name)
# Store on CPU to avoid consuming GPU memory during long runs.
# fp16 is typically sufficient because we primarily care about
# distribution and relative ordering (top-k), not exact values.
router_logits_cpu = router_logits.detach()
if router_logits_cpu.is_cuda:
router_logits_cpu = router_logits_cpu.to(device="cpu",
dtype=torch.float16)
else:
router_logits_cpu = router_logits_cpu.to(dtype=torch.float16)
bucket_snapshot: Optional[dict[str, dict[str, object]]] = None
should_flush = False
with self._lock:
bucket = self._layers_by_num_tokens.setdefault(bucket_num_tokens, {})
if layer_name in bucket:
return
bucket[layer_name] = {
"num_experts": int(num_experts),
"num_tokens": int(num_tokens),
"router_logits_chunks": [router_logits_cpu],
}
if self.cfg.max_layers != 0 and len(bucket) >= int(self.cfg.max_layers):
should_flush = True
bucket_snapshot = self._snapshot_layers_by_num_tokens(
{int(bucket_num_tokens): bucket})[int(bucket_num_tokens)]
self._completed_num_tokens.add(int(bucket_num_tokens))
self._layers_by_num_tokens.pop(int(bucket_num_tokens), None)
if should_flush and bucket_snapshot is not None:
self._flush_payload(
layers_by_num_tokens={int(bucket_num_tokens): bucket_snapshot},
file_tag=f"nt{int(bucket_num_tokens)}",
)
def _flush_payload(
self,
*,
layers_by_num_tokens: dict[int, dict[str, dict[str, object]]],
file_tag: Optional[str] = None,
) -> Optional[str]:
if not self.cfg.enabled:
return None
if self.cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != self.cfg.only_rank:
return None
rank = _get_rank()
now = time.time()
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
ts_us = int(now * 1_000_000)
with self._lock:
flush_idx = self._flush_counter
self._flush_counter += 1
rank_str = f"rank{rank}" if rank is not None else "rankNA"
tag = f"{file_tag}_" if file_tag else ""
out_path = os.path.join(
self.cfg.out_dir,
f"moe_router_stats_{tag}{ts_us}_{self._host}_{rank_str}_pid{self._pid}_flush{flush_idx}.pt",
)
layers_by_num_tokens_out: dict[str, object] = {}
for num_tokens, bucket in layers_by_num_tokens.items():
bucket_out: dict[str, object] = {}
for layer_name, state in bucket.items():
bucket_out[layer_name] = {
"num_experts": int(state["num_experts"]),
"num_tokens": int(state["num_tokens"]),
"router_logits_chunks":
state["router_logits_chunks"], # type: ignore[typeddict-item]
}
layers_by_num_tokens_out[str(int(num_tokens))] = bucket_out
payload: dict[str, object] = {
"meta": {
"timestamp": ts,
"timestamp_us": ts_us,
"flush_index": int(flush_idx),
"host": self._host,
"pid": self._pid,
"rank": rank,
"wall_time_s": float(now - self._start_time),
},
"layers_by_num_tokens": layers_by_num_tokens_out,
}
# Backward-compatible convenience field when there is a single bucket.
if len(layers_by_num_tokens) == 1:
(only_bucket_key, ) = layers_by_num_tokens.keys()
payload["layers"] = layers_by_num_tokens_out[str(int(only_bucket_key))]
try:
torch.save(payload, out_path)
except Exception:
return None
return out_path
def flush(self) -> Optional[str]:
with self._lock:
if not self._layers_by_num_tokens:
return None
snapshot = self._snapshot_layers_by_num_tokens(self._layers_by_num_tokens)
return self._flush_payload(layers_by_num_tokens=snapshot)
def reset(self) -> None:
with self._lock:
self._layers_by_num_tokens.clear()
self._layer_names.clear()
self._completed_num_tokens.clear()
self._start_time = time.time()
_CAPTURE: Optional[_RouterCapture] = None
_CAPTURE_DISABLED: bool = False
def _disable_global_capture() -> None:
global _CAPTURE, _CAPTURE_DISABLED
_CAPTURE = None
_CAPTURE_DISABLED = True
def _get_rank() -> Optional[int]:
if torch.distributed.is_available() and torch.distributed.is_initialized():
try:
return torch.distributed.get_rank()
except Exception:
return None
return None
def _get_capture() -> Optional[_RouterCapture]:
global _CAPTURE, _CAPTURE_DISABLED
if _CAPTURE_DISABLED:
return None
if _CAPTURE is not None:
return _CAPTURE
cfg = RouterCaptureConfig.from_env()
if not cfg.enabled:
_disable_global_capture()
return None
if cfg.only_rank is not None:
rank = _get_rank()
if rank is not None and rank != cfg.only_rank:
_disable_global_capture()
return None
_CAPTURE = _RouterCapture(cfg)
return _CAPTURE
@torch.no_grad()
def maybe_record_router_logits(*, layer_name: str, router_logits: torch.Tensor,
top_k: int) -> None:
capture = _get_capture()
if capture is None:
return
capture.record(layer_name=layer_name, router_logits=router_logits, top_k=top_k)
def maybe_flush_router_capture(*, reset: bool = False) -> Optional[str]:
"""Flush capture buffers to disk without exiting the process."""
capture = _get_capture()
if capture is None:
return None
out_path = capture.flush()
if out_path is not None and reset:
capture.reset()
return out_path
\ No newline at end of file
...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool): input_symmetric: bool):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
@classmethod @classmethod
......
...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -331,6 +331,10 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data weight_scale_inv = layer.weight_scale_inv.data
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.T.contiguous()
weight_scale_inv = weight_scale_inv.T.contiguous()
else:
weight = self._maybe_pad_weight(weight) weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses. # Torch.compile cannot use Parameter subclasses.
......
...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
...@@ -6,6 +6,8 @@ import functools ...@@ -6,6 +6,8 @@ import functools
import json import json
import os import os
from typing import Any, Callable, Optional, Union, List from typing import Any, Callable, Optional, Union, List
from lmslim import quant_ops
from lmslim.quantize.quant_ops import BlockSize
import torch import torch
...@@ -83,7 +85,7 @@ if current_platform.is_rocm(): ...@@ -83,7 +85,7 @@ if current_platform.is_rocm():
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
use_cutlass: bool, use_aiter_and_is_supported: bool use_cutlass: bool, use_aiter_and_is_supported: bool, use_blaslt: bool
) -> Callable[[ ) -> Callable[[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -96,6 +98,8 @@ def dispatch_w8a8_blockscale_func( ...@@ -96,6 +98,8 @@ def dispatch_w8a8_blockscale_func(
return cutlass_scaled_mm return cutlass_scaled_mm
if (use_aiter_and_is_supported): if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if use_blaslt:
return hipblaslt_w8a8_block_fp8_matmul
return w8a8_block_fp8_matmul return w8a8_block_fp8_matmul
...@@ -127,6 +131,10 @@ def apply_w8a8_block_fp8_linear( ...@@ -127,6 +131,10 @@ def apply_w8a8_block_fp8_linear(
assert input_scale is None assert input_scale is None
# View input as 2D matrix for fp8 methods # View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype output_dtype = input.dtype
...@@ -166,9 +174,12 @@ def apply_w8a8_block_fp8_linear( ...@@ -166,9 +174,12 @@ def apply_w8a8_block_fp8_linear(
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
else: else:
use_cutlass = False use_cutlass = False
use_blaslt = False
if envs.VLLM_W8A8_BACKEND == 3:
use_blaslt = True
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported) use_cutlass, use_aiter_and_is_supported, use_blaslt)
if use_cutlass: if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass) input_2d, block_size[1], column_major_scales=use_cutlass)
...@@ -197,6 +208,10 @@ def apply_w8a8_block_fp8_linear_fake( ...@@ -197,6 +208,10 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False, use_aiter_and_is_supported: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = []
if envs.VLLM_W8A8_BACKEND == 3:
output_shape = [*input.shape[:-1], weight.shape[-1]]
else:
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device) return torch.empty(output_shape, dtype=input.dtype, device=input.device)
...@@ -566,6 +581,30 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, ...@@ -566,6 +581,30 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None return None
def hipblaslt_w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m, k = A.shape
_, n = B.shape
enum_block_size = BlockSize.block_128x128
if block_size[0] == 64:
enum_block_size = BlockSize.block_64x64
elif block_size[0] == 128:
enum_block_size = BlockSize.block_128x128
else:
print(f"[WARN] Unsupported block_size: {block_size}. Falling back to BlockSize.block_128x128")
_, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(A, B, As, Bs,
m, n, k, 'NN', output_dtype,
enum_block_size, None)
return d
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
......
...@@ -12,7 +12,10 @@ from vllm.platforms import current_platform ...@@ -12,7 +12,10 @@ from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8 try:
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
except Exception:
print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None TORCH_DEVICE_IDENTITY = None
......
...@@ -232,6 +232,11 @@ def get_model_architecture( ...@@ -232,6 +232,11 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM', 'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
#针对使用dtype为fp16的情况的量化默认关闭"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
if model_config.quantization in {"awq", "awq_marlin", "moe_wna16"}:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '0'
if not envs.VLLM_USE_NN: if not envs.VLLM_USE_NN:
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
...@@ -258,6 +263,8 @@ def get_model_architecture( ...@@ -258,6 +263,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' os.environ['USE_FUSED_RMS_QUANT'] = '1'
...@@ -302,6 +309,8 @@ def get_model_architecture( ...@@ -302,6 +309,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
if not envs.is_set("USE_FUSED_RMS_QUANT"): if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1' os.environ['USE_FUSED_RMS_QUANT'] = '1'
......
...@@ -422,9 +422,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -422,9 +422,12 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant # fp16 mode not fused quant
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1]) i_q=i_q, i_s=i_s)
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
...@@ -466,9 +469,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -466,9 +469,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor)) final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else: else:
if i_q is not None:
i_q=iqis[0]
i_s=iqis[1]
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1]) i_q=i_q, i_s=i_s)
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment