"tests/vscode:/vscode.git/clone" did not exist on "328cbb2773d93d45d18dfd383d631de4fa276e69"
Commit f1bc9890 authored by zhuwenwen
Browse files

解决custom cudagraph模式需要拷贝的问题,需要配合dtk进行使用。

区分pcie和hglink custom allreduce的使用
vllm:export VLLM_CUSTOM_CACHE=1
dtk:export HIP_KERNEL_EVENT_SYSTENFENCE=1

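Enable VLLM_USE_FUSED_RMS_ROPE by default (equivalent to setting VLLM_USE_FUSED_RMS_ROPE=1).
Add SUPPORT_MOE_MARLIN_W16A16 to use the MoE Marlin W16A16 kernel on bw.
Support FP8 KV cache in flash attention (TODO: add VLLM_USE_QUERY_QUANT so query quantization can be turned off).
Update moe_align_block_size.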
parent f06d1125
...@@ -490,6 +490,7 @@ class CustomAllreduce { ...@@ -490,6 +490,7 @@ class CustomAllreduce {
std::map<IPC_KEY, char*> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
uint32_t** dev_curr_hdp_reg; uint32_t** dev_curr_hdp_reg;
hipEvent_t stopEvent;
/** /**
* Signals are an array of ipc-enabled buffers from all ranks. * Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows: * For each of the buffer, the layout is as follows:
...@@ -518,6 +519,7 @@ class CustomAllreduce { ...@@ -518,6 +519,7 @@ class CustomAllreduce {
hipDeviceGetAttribute((int*)&dev_curr_hdp_reg[i], hipDeviceAttributeHdpMemFlushCntl, i); hipDeviceGetAttribute((int*)&dev_curr_hdp_reg[i], hipDeviceAttributeHdpMemFlushCntl, i);
} }
} }
cudaEventCreate(&stopEvent);
} }
char* open_ipc_handle(const void* ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
...@@ -642,9 +644,22 @@ class CustomAllreduce { ...@@ -642,9 +644,22 @@ class CustomAllreduce {
size /= d; size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P); auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads); int blocks = std::min(block_limit, (size + threads - 1) / threads);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \ #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ { \
rank_, size, dev_curr_hdp_reg, world_size_) ; void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
...@@ -739,9 +754,22 @@ class CustomAllreduce { ...@@ -739,9 +754,22 @@ class CustomAllreduce {
} }
} }
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size);
#define KL(ngpus, name) \ #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ { \
rank_, size); void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
if (force_1stage) { \ if (force_1stage) { \
...@@ -784,6 +812,7 @@ class CustomAllreduce { ...@@ -784,6 +812,7 @@ class CustomAllreduce {
CUDACHECK(cudaIpcCloseMemHandle(ptr)); CUDACHECK(cudaIpcCloseMemHandle(ptr));
} }
cudaFree(dev_curr_hdp_reg); cudaFree(dev_curr_hdp_reg);
cudaEventDestroy(stopEvent);
} }
}; };
......
This diff is collapsed.
...@@ -281,6 +281,8 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -281,6 +281,8 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization # for attn backends supporting query quantization
self.query_quant = None self.query_quant = None
# @TODO
if envs.VLLM_USE_QUERY_QUANT:
if ( if (
self.kv_cache_dtype.startswith("fp8") self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input and self.impl.supports_quant_query_input
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
import torch
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -86,6 +87,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -86,6 +87,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
return ( return (
get_flash_attn_version() == 3 get_flash_attn_version() == 3
and current_platform.get_device_capability().major == 9 and current_platform.get_device_capability().major == 9
......
...@@ -285,6 +285,9 @@ class CustomAllreduce: ...@@ -285,6 +285,9 @@ class CustomAllreduce:
return None return None
if self._IS_CAPTURING: if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing(): if torch.cuda.is_current_stream_capturing():
if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True)
else:
return self.all_reduce(input, registered=False) return self.all_reduce(input, registered=False)
else: else:
# If warm up, mimic the allocation pattern since custom # If warm up, mimic the allocation pattern since custom
......
...@@ -248,12 +248,14 @@ if TYPE_CHECKING: ...@@ -248,12 +248,14 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT: int | None = None VLLM_OPTEST_URLS_PORT: int | None = None
VLLM_OPTEST_MODELS_PATH: str = "" VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_CUSTOM_CACHE: bool = False
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: int | None = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: int | None = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
...@@ -1622,6 +1624,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1622,6 +1624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations. # If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
...@@ -1649,6 +1656,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1649,6 +1656,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE": "VLLM_PCIE_USE_CUSTOM_ALLREDUCE":
lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "1"))), lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
lambda: int(os.getenv("VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX", "16")), lambda: int(os.getenv("VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX", "16")),
...@@ -1750,7 +1761,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1750,7 +1761,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vLLM will use fused RMS + RoPE kernel # vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts # vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
......
...@@ -115,7 +115,7 @@ def moe_align_block_size( ...@@ -115,7 +115,7 @@ def moe_align_block_size(
expert_map = expert_map, expert_map = expert_map,
expert_mask = expert_mask, expert_mask = expert_mask,
num_local_tokens = None, num_local_tokens = None,
Is_fuse_fill = False, Is_fuse_fill = True,
) )
else: else:
if envs.VLLM_USE_LIGHTOP_MOE_ALIGN: if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
...@@ -130,7 +130,7 @@ def moe_align_block_size( ...@@ -130,7 +130,7 @@ def moe_align_block_size(
expert_map = None, expert_map = None,
expert_mask = None, expert_mask = None,
num_local_tokens = None, num_local_tokens = None,
Is_fuse_fill = False, Is_fuse_fill = True,
) )
else: else:
ops.moe_align_block_size( ops.moe_align_block_size(
......
...@@ -1153,7 +1153,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1153,7 +1153,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,**_,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb: if layer.enable_eplb:
......
...@@ -202,11 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -202,11 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
# if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"): # if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
# os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1' # os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): # if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' # os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]: if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
...@@ -234,11 +234,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -234,11 +234,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
# if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"): # if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
# os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1' # os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): # if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' # os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]: if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
......
...@@ -15,6 +15,12 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -15,6 +15,12 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
if SUPPORT_MOE_MARLIN_W16A16:
os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
os.environ['MOE_NN'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
...@@ -17,6 +17,9 @@ _DEPRECATED_MAPPINGS = { ...@@ -17,6 +17,9 @@ _DEPRECATED_MAPPINGS = {
"get_open_port": "network_utils", "get_open_port": "network_utils",
} }
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
SUPPORT_MOE_MARLIN_W16A16 = any(arch in GPU_ARCH for arch in ["gfx936"])
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
"""Module-level getattr to handle deprecated utilities.""" """Module-level getattr to handle deprecated utilities."""
......
...@@ -183,6 +183,8 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -183,6 +183,8 @@ class FlashAttentionBackend(AttentionBackend):
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"): if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn return torch.float8_e4m3fn
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment