"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "814133ec9cdef62404b8731934fc1ce229d0d948"
Unverified commit 8e64140e authored by Xiaoyu Zhang, committed by GitHub

[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)

parent 82f021e2
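This change wires TensorRT-LLM's fused allreduce + residual-add + RMSNorm kernel (exposed via flashinfer.comm) into the tensor-parallel decode path on SM100 (Blackwell) GPUs, gated by a new --enable-flashinfer-allreduce-fusion server flag and a 1024-token cap. For reference, a minimal unfused sketch of what the fused kernel computes; the helper name and the fp32 upcast are illustrative assumptions, not the kernel's implementation:

import torch
import torch.distributed as dist


def unfused_allreduce_add_rmsnorm(x, residual, weight, eps=1e-6):
    # 1) Sum the per-rank partial hidden states across the tensor-parallel group.
    dist.all_reduce(x)
    # 2) Add the residual; this sum is also returned as the next layer's residual.
    residual_out = x + residual
    # 3) RMSNorm over the summed hidden states (fp32 upcast assumed for stability).
    h = residual_out.float()
    h = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    norm_out = (h * weight.float()).to(residual_out.dtype)
    return norm_out, residual_out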
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
 
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
 
 class ScatterMode(Enum):
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod
...

New module sglang.srt.layers.flashinfer_comm_fusion (added by this commit):
import logging
from typing import Tuple
import torch
import torch.distributed as dist
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available
logger = logging.getLogger(__name__)
_flashinfer_comm = None
_workspace_manager = None
if is_flashinfer_available():
try:
import flashinfer.comm as comm
_flashinfer_comm = comm
except ImportError:
logger.warning(
"flashinfer.comm is not available, falling back to standard "
"implementation"
)
class FlashInferWorkspaceManager:
def __init__(self):
self.workspace_tensor = None
self.ipc_handles = None
self.world_size = None
self.rank = None
self.initialized = False
def initialize(
self,
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
group=None,
use_fp32_lamport: bool = False,
):
"""Initialize workspace"""
if self.initialized and self.world_size == world_size:
return
if _flashinfer_comm is None:
logger.warning(
"FlashInfer comm not available, skipping workspace " "initialization"
)
return
self.cleanup()
self.ipc_handles, self.workspace_tensor = (
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
rank,
world_size,
max_token_num,
hidden_dim,
group=group,
use_fp32_lamport=use_fp32_lamport,
)
)
self.world_size = world_size
self.rank = rank
self.initialized = True
logger.info(
f"FlashInfer workspace initialized for rank {rank}, "
f"world_size {world_size}"
)
def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.ipc_handles is not None:
try:
_flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
self.ipc_handles, group=dist.group.WORLD
)
except Exception as e:
logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
finally:
self.workspace_tensor = None
self.ipc_handles = None
self.initialized = False
_workspace_manager = FlashInferWorkspaceManager()
def ensure_workspace_initialized(
max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
"""Ensure workspace is initialized"""
if not is_flashinfer_available() or _flashinfer_comm is None:
return False
world_size = get_tensor_model_parallel_world_size()
if world_size <= 1:
return False
rank = dist.get_rank()
if (
not _workspace_manager.initialized
or _workspace_manager.world_size != world_size
):
_workspace_manager.initialize(
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
use_fp32_lamport=use_fp32_lamport,
)
return _workspace_manager.initialized
def flashinfer_allreduce_add_rmsnorm(
input_tensor: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
max_token_num: int = 1024,
use_oneshot: bool = True,
trigger_completion_at_end: bool = False,
fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Use FlashInfer's fused allreduce + residual + RMS norm operation
Args:
input_tensor: Input tensor that needs allreduce
residual: Residual tensor
weight: RMS norm weight
eps: RMS norm epsilon
max_token_num: Maximum token number
use_oneshot: Whether to use oneshot mode
trigger_completion_at_end: Whether to trigger completion at end
fp32_acc: Whether to use fp32 accumulation
Returns:
Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
"""
if not is_flashinfer_available() or _flashinfer_comm is None:
logger.debug(
"FlashInfer not available, falling back to standard " "implementation"
)
return None, None
world_size = get_tensor_model_parallel_world_size()
if world_size <= 1:
logger.debug("Single GPU, no need for allreduce fusion")
return None, None
if not ensure_workspace_initialized(
max_token_num=max_token_num,
hidden_dim=input_tensor.shape[-1],
use_fp32_lamport=(input_tensor.dtype == torch.float32),
):
logger.debug("FlashInfer workspace not available")
return None, None
token_num, hidden_dim = input_tensor.shape
residual_out = torch.empty_like(residual)
norm_out = torch.empty_like(input_tensor)
_flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
world_size=world_size,
world_rank=dist.get_rank(),
token_num=token_num,
hidden_dim=hidden_dim,
workspace_ptrs=_workspace_manager.workspace_tensor,
launch_with_pdl=True,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
allreduce_out=None,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=eps,
scale_factor=None,
layout_code=None,
)
return norm_out, residual_out
def cleanup_flashinfer_workspace():
global _workspace_manager
if _workspace_manager is not None:
_workspace_manager.cleanup()
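A hedged usage sketch of the helper's fallback contract: it returns (None, None) whenever the fused path cannot run (flashinfer missing, single GPU, or workspace initialization failed), so callers keep the unfused allreduce + RMSNorm sequence around, exactly as the RMSNorm change below does. Variable names are illustrative:

from sglang.srt.distributed import tensor_model_parallel_all_reduce
from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm

# hidden_states, residual, and the RMSNorm module `layernorm` are assumed to exist.
norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
    input_tensor=hidden_states,
    residual=residual,
    weight=layernorm.weight,
    eps=layernorm.variance_epsilon,
)
if norm_out is not None:
    hidden_states, residual = norm_out, residual_out
else:
    # Fused kernel unavailable: fall back to the standard unfused sequence.
    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
    hidden_states, residual = layernorm(hidden_states, residual)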
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
...
@@ -85,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
...
@@ -157,6 +157,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -1206,6 +1207,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
...
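The new flag defaults to off and is exposed to model code through global_server_args_dict via the GLOBAL_SERVER_ARGS_KEYS entry above. A sketch of enabling it programmatically; the model path and TP size are placeholders, and the CLI equivalent is the --enable-flashinfer-allreduce-fusion switch added here:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",  # placeholder
    tp_size=8,                    # placeholder; the fusion only applies when tp_size > 1
    enable_flashinfer_allreduce_fusion=True,
)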