Commit f1bc9890 authored by zhuwenwen
Browse files

解决custom cudagraph模式需要拷贝的问题,需要配合dtk进行使用。

区分pcie和hglink custom allreduce的使用
vllm:export VLLM_CUSTOM_CACHE=1
dtk:export HIP_KERNEL_EVENT_SYSTENFENCE=1

set VLLM_USE_FUSED_RMS_ROPE=1
add SUPPORT_MOE_MARLIN_W16A16 to use moe marlin on bw
support fa kvcache fp8 (todo: add VLLM_USE_QUERY_QUANT to not use q quant)
update moe_align_block_size
parent f06d1125
......@@ -490,6 +490,7 @@ class CustomAllreduce {
std::map<IPC_KEY, char*> ipc_handles_;
uint32_t** dev_curr_hdp_reg;
hipEvent_t stopEvent;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
......@@ -518,6 +519,7 @@ class CustomAllreduce {
hipDeviceGetAttribute((int*)&dev_curr_hdp_reg[i], hipDeviceAttributeHdpMemFlushCntl, i);
}
}
cudaEventCreate(&stopEvent);
}
char* open_ipc_handle(const void* ipc_handle) {
......@@ -642,10 +644,23 @@ class CustomAllreduce {
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
// Launches name<T, ngpus> on `stream` through hipExtLaunchKernel, passing
// stopEvent as the launch's stop event. kernelArgs must contain one pointer
// per kernel parameter, in declaration order. The <<<...>>> launch this macro
// replaces passed eight arguments, ending with dev_curr_hdp_reg and
// world_size_; leaving them out would make the runtime read past the end of
// kernelArgs for the last two parameters.
#define KL(ngpus, name)                                      \
  {                                                          \
    void* kernelArgs[] = {                                   \
        &ptrs, &sg_, &self_sg_, &output, &rank_, &size,      \
        &dev_curr_hdp_reg, &world_size_                      \
    };                                                       \
    hipExtLaunchKernel(                                      \
        (void*)name<T, ngpus>,                               \
        blocks, threads,                                     \
        kernelArgs, 0,                                       \
        stream, nullptr, stopEvent, 0                        \
    );                                                       \
  }
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
......@@ -739,9 +754,22 @@ class CustomAllreduce {
}
}
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
{ \
void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (force_1stage) { \
......@@ -784,6 +812,7 @@ class CustomAllreduce {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
cudaFree(dev_curr_hdp_reg);
cudaEventDestroy(stopEvent);
}
};
......
This diff is collapsed.
......@@ -281,11 +281,13 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
# @TODO
if envs.VLLM_USE_QUERY_QUANT:
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
def forward(
self,
......
......@@ -3,6 +3,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -86,6 +87,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
return (
get_flash_attn_version() == 3
and current_platform.get_device_capability().major == 9
......
......@@ -285,7 +285,10 @@ class CustomAllreduce:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=False)
if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True)
else:
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
......
......@@ -248,12 +248,14 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT: int | None = None
VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_CUSTOM_CACHE: bool = False
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: int | None = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
......@@ -1622,6 +1624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......@@ -1649,6 +1656,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE":
lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "1"))),
# flag to control whether custom allreduce treats its input as a registered buffer while a stream is capturing (enabled unless set to 0)
"VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
lambda: int(os.getenv("VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX", "16")),
......@@ -1750,7 +1761,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
......
......@@ -115,7 +115,7 @@ def moe_align_block_size(
expert_map = expert_map,
expert_mask = expert_mask,
num_local_tokens = None,
Is_fuse_fill = False,
Is_fuse_fill = True,
)
else:
if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
......@@ -130,7 +130,7 @@ def moe_align_block_size(
expert_map = None,
expert_mask = None,
num_local_tokens = None,
Is_fuse_fill = False,
Is_fuse_fill = True,
)
else:
ops.moe_align_block_size(
......
......@@ -1153,7 +1153,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
router_logits: torch.Tensor,**_,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb:
......
......@@ -202,11 +202,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
# if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
# os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
# if not envs.is_set("VLLM_USE_PD_SPLIT"):
# os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
......@@ -234,11 +234,11 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
# if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
# os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
# if not envs.is_set("VLLM_USE_PD_SPLIT"):
# os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
......
......@@ -15,6 +15,12 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
if SUPPORT_MOE_MARLIN_W16A16:
os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
os.environ['MOE_NN'] = '0'
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
......
......@@ -17,6 +17,9 @@ _DEPRECATED_MAPPINGS = {
"get_open_port": "network_utils",
}
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
SUPPORT_MOE_MARLIN_W16A16 = any(arch in GPU_ARCH for arch in ["gfx936"])
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
"""Module-level getattr to handle deprecated utilities."""
......
......@@ -183,6 +183,8 @@ class FlashAttentionBackend(AttentionBackend):
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment