Unverified Commit 88fbc31b authored by strgrb, committed by GitHub

Support trtllm_allreduce_fusion in flashinfer for cuda<12.8 (#9339)


Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 8f5b9910
@@ -292,7 +292,6 @@ class LayerCommunicator:
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
             and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
             and _is_flashinfer_available
         )
@@ -5,7 +5,11 @@
 import torch
 import torch.distributed as dist

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)

 logger = logging.getLogger(__name__)
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out


+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
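The hunk above exposes the fused allreduce + residual-add + RMSNorm path as a torch custom op and pairs it with a "fake" implementation that only allocates outputs of the right shape and dtype, so torch.compile can trace through the call without launching the real FlashInfer kernel. Below is a minimal, self-contained sketch of that pattern using the plain torch.library API (assuming PyTorch >= 2.4); the "demo" namespace, the helper names, and the non-communicating body are illustrative stand-ins, not sglang's direct_register_custom_op helper or the real kernel.

import torch

# Illustrative namespace; sglang registers its op under "sglang" instead.
lib = torch.library.Library("demo", "FRAGMENT")
lib.define(
    "allreduce_residual_rmsnorm(Tensor input_tensor, Tensor residual, "
    "Tensor weight, float eps) -> (Tensor, Tensor)"
)

def _real_impl(input_tensor, residual, weight, eps):
    # Stand-in for the fused allreduce + residual-add + RMSNorm kernel;
    # here only the residual add and norm, with no communication.
    hidden = input_tensor + residual
    norm = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps) * weight
    return norm, hidden

def _fake_impl(input_tensor, residual, weight, eps):
    # Shape/dtype-only variant, mirroring fake_flashinfer_allreduce_residual_rmsnorm:
    # lets torch.compile trace the op without running the real kernel.
    return torch.empty_like(input_tensor), torch.empty_like(residual)

lib.impl("allreduce_residual_rmsnorm", _real_impl, "CompositeExplicitAutograd")
torch.library.register_fake("demo::allreduce_residual_rmsnorm", _fake_impl)

# Once registered, the op is reachable through torch.ops, as in the next hunk.
norm_out, residual_out = torch.ops.demo.allreduce_residual_rmsnorm(
    torch.randn(2, 8), torch.randn(2, 8), torch.ones(8), 1e-6
)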
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )

 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
                 flashinfer_allreduce_residual_rmsnorm,
             )

+            fused_op = (
+                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+                if supports_custom_op()
+                else flashinfer_allreduce_residual_rmsnorm
+            )
+
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_residual_rmsnorm(
+                fused_result = fused_op(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
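The design choice in this hunk: when custom-op registration is available, RMSNorm routes the fused path through torch.ops.sglang.* so graph capture sees a single opaque node backed by the fake implementation; otherwise it falls back to calling the Python function directly. A self-contained sketch of that guarded dispatch follows, again assuming a recent PyTorch; the "demo" namespace and helper names are illustrative, not sglang's.

import torch

def rmsnorm_eager(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain Python fallback used when custom-op registration is unavailable.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def supports_custom_op() -> bool:
    # Rough stand-in for sglang's helper: recent torch.library APIs present.
    return hasattr(torch.library, "register_fake")

if supports_custom_op():
    lib = torch.library.Library("demo", "FRAGMENT")
    lib.define("rmsnorm(Tensor x, Tensor weight, float eps) -> Tensor")
    lib.impl("rmsnorm", rmsnorm_eager, "CompositeExplicitAutograd")
    torch.library.register_fake("demo::rmsnorm", lambda x, w, eps: torch.empty_like(x))

# Prefer the registered op (an opaque node for torch.compile), else fall back.
fused_op = torch.ops.demo.rmsnorm if supports_custom_op() else rmsnorm_eager
out = fused_op(torch.randn(2, 8), torch.ones(8), 1e-6)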