Unverified commit addac0e6, authored by Luka Govedič, committed by GitHub
Browse files

[torch.compile] Enable AR+rms fusion by default, where available, for `-O2` (#34299)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
parent 675a22ed
...@@ -115,7 +115,7 @@ class PassConfig: ...@@ -115,7 +115,7 @@ class PassConfig:
"""Fuse the custom SiluMul + quant ops.""" """Fuse the custom SiluMul + quant ops."""
fuse_attn_quant: bool = Field(default=None) fuse_attn_quant: bool = Field(default=None)
"""Fuse the custom attention + quant ops.""" """Fuse the custom attention + quant ops."""
eliminate_noops: bool = Field(default=None) eliminate_noops: bool = Field(default=True)
"""Eliminate no-op ops.""" """Eliminate no-op ops."""
enable_sp: bool = Field(default=None) enable_sp: bool = Field(default=None)
"""Enable sequence parallelism.""" """Enable sequence parallelism."""
...@@ -194,7 +194,6 @@ class PassConfig: ...@@ -194,7 +194,6 @@ class PassConfig:
"fuse_norm_quant", "fuse_norm_quant",
"fuse_act_quant", "fuse_act_quant",
"fuse_attn_quant", "fuse_attn_quant",
"eliminate_noops",
"enable_sp", "enable_sp",
"fuse_gemm_comms", "fuse_gemm_comms",
"fuse_allreduce_rms", "fuse_allreduce_rms",
......
...@@ -102,6 +102,19 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool: ...@@ -102,6 +102,19 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the all-reduce + RMSNorm fusion should be turned on.

    The fusion is enabled only when every one of these holds, checked in
    this order (later checks are skipped once one fails):

    1. tensor parallelism is in use (``tensor_parallel_size > 1``),
    2. the current platform is CUDA,
    3. the device reports compute capability 90 or newer (Hopper+),
    4. flashinfer is installed.

    Args:
        cfg: The vLLM config; only ``parallel_config.tensor_parallel_size``
            is read.

    Returns:
        True if the fusion should be enabled, False otherwise.
    """
    # Imported lazily inside the function, as in the original definition.
    from vllm.platforms import current_platform
    from vllm.utils.flashinfer import has_flashinfer

    tp_size = cfg.parallel_config.tensor_parallel_size
    if not tp_size > 1:
        return False

    if not current_platform.is_cuda():
        return False

    if not current_platform.has_device_capability(90):
        return False

    return has_flashinfer()
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs """Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion.""" and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
...@@ -118,7 +131,6 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: ...@@ -118,7 +131,6 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
OPTIMIZATION_LEVEL_00 = { OPTIMIZATION_LEVEL_00 = {
"compilation_config": { "compilation_config": {
"pass_config": { "pass_config": {
"eliminate_noops": False,
"fuse_norm_quant": False, "fuse_norm_quant": False,
"fuse_act_quant": False, "fuse_act_quant": False,
"fuse_allreduce_rms": False, "fuse_allreduce_rms": False,
...@@ -137,7 +149,6 @@ OPTIMIZATION_LEVEL_00 = { ...@@ -137,7 +149,6 @@ OPTIMIZATION_LEVEL_00 = {
OPTIMIZATION_LEVEL_01 = { OPTIMIZATION_LEVEL_01 = {
"compilation_config": { "compilation_config": {
"pass_config": { "pass_config": {
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion, "fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion, "fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False, "fuse_allreduce_rms": False,
...@@ -156,10 +167,9 @@ OPTIMIZATION_LEVEL_01 = { ...@@ -156,10 +167,9 @@ OPTIMIZATION_LEVEL_01 = {
OPTIMIZATION_LEVEL_02 = { OPTIMIZATION_LEVEL_02 = {
"compilation_config": { "compilation_config": {
"pass_config": { "pass_config": {
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion, "fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion, "fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False, "fuse_allreduce_rms": enable_allreduce_rms_fusion,
"fuse_attn_quant": IS_QUANTIZED, "fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE, "enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE, "fuse_gemm_comms": IS_DENSE,
...@@ -175,10 +185,9 @@ OPTIMIZATION_LEVEL_02 = { ...@@ -175,10 +185,9 @@ OPTIMIZATION_LEVEL_02 = {
OPTIMIZATION_LEVEL_03 = { OPTIMIZATION_LEVEL_03 = {
"compilation_config": { "compilation_config": {
"pass_config": { "pass_config": {
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion, "fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion, "fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False, "fuse_allreduce_rms": enable_allreduce_rms_fusion,
"fuse_attn_quant": IS_QUANTIZED, "fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE, "enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE, "fuse_gemm_comms": IS_DENSE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment