Unverified Commit 2e7ab862 authored by Xiaoyu Zhang, committed by GitHub

Fix illegal memory in trtllm allreduce fusion (#7864)

parent 51ae4030
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
...
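Per the linked FlashInfer discussion, the one-shot (use_oneshot=True, min-latency) kernel should only be fed up to 128 tokens; larger batches are what triggered the illegal memory access, so both this guard and the workspace's max_token_num drop from 1024 to 128. A minimal sketch of the dispatch pattern follows; allreduce_rmsnorm_dispatch and the two callables are illustrative stand-ins, not sglang APIs.

# Sketch only: hypothetical helper mirroring the guard above, not sglang code.
from typing import Callable, Optional, Tuple

import torch

# Token capacity of the one-shot path; must match the value the FlashInfer
# workspace was initialized with (128 after this fix).
MAX_ONESHOT_TOKEN_NUM = 128


def allreduce_rmsnorm_dispatch(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    fused_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]],
    unfused_fn: Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Take the fused allreduce + residual add + RMSNorm path only when the batch fits."""
    if fused_fn is not None and hidden_states.shape[0] <= MAX_ONESHOT_TOKEN_NUM:
        return fused_fn(hidden_states, residual)
    # Larger batches fall back to the unfused path instead of overrunning the
    # one-shot kernel's workspace.
    return unfused_fn(hidden_states, residual)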
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
...
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized


-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
...
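In the fusion module itself the two defaults are kept in sync at 128, and flashinfer_allreduce_add_rmsnorm is renamed to flashinfer_allreduce_residual_rmsnorm. The sketch below shows the "check capacity, otherwise decline" contract a caller can rely on; maybe_fused_allreduce_residual_rmsnorm and _FUSION_MAX_TOKEN_NUM are made-up names, and the FlashInfer kernel launch is replaced by equivalent single-device math.

# Illustrative only: the capacity check is the point; the real module
# initializes the FlashInfer workspace and launches the fused kernel instead
# of the plain PyTorch math used here.
from typing import Optional, Tuple

import torch

_FUSION_MAX_TOKEN_NUM = 128  # one-shot workspace capacity, same value as the call-site guard


def maybe_fused_allreduce_residual_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Return (output, new_residual), or None so the caller falls back to the unfused path."""
    if input_tensor.shape[0] > _FUSION_MAX_TOKEN_NUM:
        # Feeding more tokens than the workspace was provisioned for is what
        # caused the illegal memory access; decline rather than overrun it.
        return None
    # Real module: ensure_workspace_initialized(max_token_num=_FUSION_MAX_TOKEN_NUM, ...)
    # followed by the fused one-shot allreduce + residual add + RMSNorm kernel.
    # Stand-in here: the same math, unfused, on a single device.
    added = input_tensor + residual
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    output = (added.float() * torch.rsqrt(variance + eps)).to(input_tensor.dtype) * weight
    return output, added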
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
...
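At the RMSNorm call site the fused helper remains best-effort: it is only tried under tensor parallelism, and when it yields no result the layer falls through to the ordinary residual add plus norm. A toy version of that call-site pattern follows, assuming the made-up RMSNormSketch class and a try_fused callable such as the stand-in above; it is not sglang's RMSNorm.

# Toy illustration of the "try fused, else unfused" call-site pattern.
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        try_fused: Optional[Callable[..., Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if try_fused is not None:
            fused = try_fused(input_tensor=x, residual=residual, weight=self.weight, eps=self.eps)
            if fused is not None:
                # (normed_output, updated_residual) came from the fused kernel.
                return fused
        # Unfused fallback: residual add, then RMSNorm.
        x = x + residual
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        out = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype) * self.weight
        return out, x

Calling RMSNormSketch(hidden_size)(x, residual, try_fused=maybe_fused_allreduce_residual_rmsnorm) exercises the capacity check from the previous sketch and falls back cleanly once a batch exceeds 128 tokens.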