Unverified Commit 6042e66c authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm] Add extra step in config initialization to populate custom ops before...


[ROCm] Add extra step in config initialization to populate custom ops before compilation config init (#34848)
Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 9f9a675b
......@@ -809,6 +809,8 @@ class VllmConfig:
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
current_platform.apply_config_platform_defaults(self)
if self.compilation_config.mode is None:
if self.optimization_level > OptimizationLevel.O0:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
......
......@@ -393,6 +393,20 @@ class Platform:
"""
pass
@classmethod
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
"""
Apply the platform-specific default values to the config.
This function is called during the initialization of global VllmConfig, after
parsing cli arguments.
It can modify the defaults of the config according to the platform. For example,
it can enable custom_ops based on the enabled features.
The config is passed by reference, so it can be modified in place.
"""
pass
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
......
......@@ -482,19 +482,61 @@ class RocmPlatform(Platform):
return device_props.total_memory
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.compilation import CUDAGraphMode
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE
is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if (
use_aiter_rms_norm
and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
logger.warning_once(
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
"requires the 'grouped_topk' custom op. Overriding the "
"user-provided '-grouped_topk'."
)
compilation_config.custom_ops.remove("-grouped_topk")
# Ensure grouped_topk is always enabled when using AITER if
# its not disabled by user
if (
use_aiter_fused_moe
and "+grouped_topk" not in compilation_config.custom_ops
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if its not disabled by user
if (
use_aiter_triton_rope
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.config.compilation import CUDAGraphMode
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
......@@ -533,42 +575,6 @@ class RocmPlatform(Platform):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if (
use_aiter_rms_norm
and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
logger.warning_once(
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
"requires the 'grouped_topk' custom op. Overriding the "
"user-provided '-grouped_topk'."
)
compilation_config.custom_ops.remove("-grouped_topk")
# Ensure grouped_topk is always enabled when using AITER if
# its not disabled by user
if (
use_aiter_fused_moe
and "+grouped_topk" not in compilation_config.custom_ops
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if its not disabled by user
if (
use_aiter_triton_rope
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment