Unverified commit d4f123cc, authored by Mohammad Miadh Angkad and committed by GitHub
Browse files

[Kernel] FlashInfer: switch allreduce fusion to unified API (#33985)


Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
parent cb62e86f
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
Benchmark for FlashInfer fused collective operations vs standard operations. Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares: This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) 1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations 2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun: Usage with torchrun:
...@@ -24,7 +24,6 @@ import torch.distributed as dist # type: ignore ...@@ -24,7 +24,6 @@ import torch.distributed as dist # type: ignore
from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
get_tp_group,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -52,11 +51,12 @@ logger = init_logger(__name__) ...@@ -52,11 +51,12 @@ logger = init_logger(__name__)
try: try:
import flashinfer.comm as flashinfer_comm # type: ignore import flashinfer.comm as flashinfer_comm # type: ignore
if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): if not (
hasattr(flashinfer_comm, "allreduce_fusion")
and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace")
):
flashinfer_comm = None flashinfer_comm = None
logger.warning( logger.warning("FlashInfer comm module found but missing allreduce_fusion API")
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except ImportError: except ImportError:
flashinfer_comm = None flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations") logger.warning("FlashInfer not found, only benchmarking standard operations")
...@@ -75,7 +75,7 @@ _FI_MAX_SIZES = { ...@@ -75,7 +75,7 @@ _FI_MAX_SIZES = {
} }
# Global workspace tensor for FlashInfer # Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None _FI_WORKSPACE = None
def setup_flashinfer_workspace( def setup_flashinfer_workspace(
...@@ -83,10 +83,10 @@ def setup_flashinfer_workspace( ...@@ -83,10 +83,10 @@ def setup_flashinfer_workspace(
rank: int, rank: int,
hidden_dim: int, hidden_dim: int,
max_token_num: int, max_token_num: int,
use_fp32_lamport: bool = False, dtype: torch.dtype,
): ):
"""Setup FlashInfer workspace for fused allreduce operations.""" """Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE_TENSOR global _FI_WORKSPACE
if flashinfer_comm is None: if flashinfer_comm is None:
return None, None return None, None
...@@ -96,33 +96,29 @@ def setup_flashinfer_workspace( ...@@ -96,33 +96,29 @@ def setup_flashinfer_workspace(
return None, None return None, None
try: try:
# Create IPC workspace workspace = flashinfer_comm.create_allreduce_fusion_workspace(
ipc_handles, workspace_tensor = ( backend="trtllm",
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( world_size=world_size,
tp_rank=rank, rank=rank,
tp_size=world_size, max_token_num=max_token_num,
max_token_num=max_token_num, hidden_dim=hidden_dim,
hidden_dim=hidden_dim, dtype=dtype,
group=get_tp_group().device_group,
use_fp32_lamport=use_fp32_lamport,
)
) )
_FI_WORKSPACE_TENSOR = workspace_tensor _FI_WORKSPACE = workspace
return ipc_handles, workspace_tensor return workspace
except Exception as e: except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e) logger.error("Failed to setup FlashInfer workspace: %s", e)
return None, None return None
def cleanup_flashinfer_workspace(ipc_handles): def cleanup_flashinfer_workspace(workspace):
"""Cleanup FlashInfer workspace.""" """Cleanup FlashInfer workspace."""
if flashinfer_comm is None or ipc_handles is None: if flashinfer_comm is None or workspace is None:
return return
try: try:
group = get_tp_group().device_group workspace.destroy()
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
except Exception as e: except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e) logger.error("Failed to cleanup FlashInfer workspace: %s", e)
...@@ -132,25 +128,15 @@ class FlashInferFusedAllReduceParams: ...@@ -132,25 +128,15 @@ class FlashInferFusedAllReduceParams:
def __init__( def __init__(
self, self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024, max_token_num: int = 1024,
): ):
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True self.launch_with_pdl = True
self.fp32_acc = True self.fp32_acc = True
self.max_token_num = max_token_num self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self): def get_trtllm_fused_allreduce_kwargs(self):
return { return {
"world_rank": self.rank,
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl, "launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc, "fp32_acc": self.fp32_acc,
} }
...@@ -165,7 +151,7 @@ def flashinfer_fused_allreduce_rmsnorm( ...@@ -165,7 +151,7 @@ def flashinfer_fused_allreduce_rmsnorm(
norm_out: torch.Tensor | None = None, norm_out: torch.Tensor | None = None,
): ):
"""FlashInfer fused allreduce + rmsnorm operation.""" """FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized") raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None: if norm_out is None:
...@@ -174,18 +160,15 @@ def flashinfer_fused_allreduce_rmsnorm( ...@@ -174,18 +160,15 @@ def flashinfer_fused_allreduce_rmsnorm(
else: else:
residual_out = input_tensor residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion( flashinfer_comm.allreduce_fusion(
allreduce_in=input_tensor, input=input_tensor,
token_num=input_tensor.shape[0], workspace=_FI_WORKSPACE,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
residual_in=residual, residual_in=residual,
residual_out=residual_out, residual_out=residual_out,
norm_out=norm_out, norm_out=norm_out,
rms_gamma=rms_gamma, rms_gamma=rms_gamma,
rms_eps=rms_eps, rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
allreduce_out=None,
quant_out=None, quant_out=None,
scale_out=None, scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
...@@ -207,7 +190,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( ...@@ -207,7 +190,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
quant_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None,
): ):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" """FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized") raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None: if norm_out is None:
...@@ -216,18 +199,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( ...@@ -216,18 +199,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
else: else:
residual_out = input_tensor residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion( flashinfer_comm.allreduce_fusion(
allreduce_in=input_tensor, input=input_tensor,
token_num=input_tensor.shape[0], workspace=_FI_WORKSPACE,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
residual_in=residual, residual_in=residual,
residual_out=residual_out, residual_out=residual_out,
norm_out=norm_out, norm_out=norm_out,
rms_gamma=rms_gamma, rms_gamma=rms_gamma,
rms_eps=rms_eps, rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
allreduce_out=None,
quant_out=quant_out, quant_out=quant_out,
scale_out=None, scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
...@@ -250,7 +230,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( ...@@ -250,7 +230,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
norm_out: torch.Tensor | None = None, norm_out: torch.Tensor | None = None,
): ):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" """FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized") raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None: if norm_out is None:
...@@ -259,18 +239,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( ...@@ -259,18 +239,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
else: else:
residual_out = input_tensor residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion( flashinfer_comm.allreduce_fusion(
allreduce_in=input_tensor, input=input_tensor,
token_num=input_tensor.shape[0], workspace=_FI_WORKSPACE,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
residual_in=residual, residual_in=residual,
residual_out=residual_out, residual_out=residual_out,
norm_out=norm_out, norm_out=norm_out,
rms_gamma=rms_gamma, rms_gamma=rms_gamma,
rms_eps=rms_eps, rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
allreduce_out=None,
quant_out=quant_out, quant_out=quant_out,
scale_out=output_scale, scale_out=output_scale,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
...@@ -1040,23 +1017,31 @@ def main(): ...@@ -1040,23 +1017,31 @@ def main():
configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
# Setup FlashInfer workspace if available # Setup FlashInfer workspace if available
ipc_handles = None workspace = None
allreduce_params = None allreduce_params = None
if flashinfer_comm is not None: if flashinfer_comm is not None:
# Use the largest hidden dimension for workspace setup # Use the largest hidden dimension for workspace setup
max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes)
workspace_dtype = (
torch.float32
if max_element_size == 4
else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16)
)
max_num_token = _FI_MAX_SIZES.get(world_size) // ( max_num_token = _FI_MAX_SIZES.get(world_size) // (
args.hidden_dim * world_size * 2 args.hidden_dim * max_element_size
) )
ipc_handles, workspace_tensor = setup_flashinfer_workspace( workspace = setup_flashinfer_workspace(
world_size, rank, args.hidden_dim, max_num_token world_size,
rank,
args.hidden_dim,
max_num_token,
dtype=workspace_dtype,
) )
if workspace_tensor is not None: if workspace is not None:
allreduce_params = FlashInferFusedAllReduceParams( allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
max_token_num=max_num_token, max_token_num=max_num_token,
) )
...@@ -1119,8 +1104,8 @@ def main(): ...@@ -1119,8 +1104,8 @@ def main():
finally: finally:
# Cleanup # Cleanup
if ipc_handles is not None: if workspace is not None:
cleanup_flashinfer_workspace(ipc_handles) cleanup_flashinfer_workspace(workspace)
dist.barrier() dist.barrier()
......
...@@ -202,9 +202,10 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): ...@@ -202,9 +202,10 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif( @pytest.mark.skipif(
not find_spec("flashinfer") not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
reason="flashinfer is not found or flashinfer " reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion", "is not compiled with allreduce_fusion",
) )
def test_all_reduce_fusion_pass_replace( def test_all_reduce_fusion_pass_replace(
test_model: torch.nn.Module, test_model: torch.nn.Module,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from importlib.util import find_spec from importlib.util import find_spec
from types import ModuleType from types import ModuleType
...@@ -36,7 +37,9 @@ if find_spec("flashinfer"): ...@@ -36,7 +37,9 @@ if find_spec("flashinfer"):
try: try:
import flashinfer.comm as _flashinfer_comm import flashinfer.comm as _flashinfer_comm
if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"): if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
_flashinfer_comm, "create_allreduce_fusion_workspace"
):
flashinfer_comm = _flashinfer_comm flashinfer_comm = _flashinfer_comm
except ImportError: except ImportError:
pass pass
...@@ -79,7 +82,7 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = { ...@@ -79,7 +82,7 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
if flashinfer_comm is not None: if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None _FI_WORKSPACE = None
MiB = 1024 * 1024 MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm( def call_trtllm_fused_allreduce_norm(
...@@ -87,10 +90,8 @@ if flashinfer_comm is not None: ...@@ -87,10 +90,8 @@ if flashinfer_comm is not None:
residual: torch.Tensor, residual: torch.Tensor,
rms_gamma: torch.Tensor, rms_gamma: torch.Tensor,
rms_eps: float, rms_eps: float,
world_rank: int,
world_size: int, world_size: int,
launch_with_pdl: bool, launch_with_pdl: bool,
trigger_completion_at_end: bool,
fp32_acc: bool, fp32_acc: bool,
max_token_num: int, max_token_num: int,
pattern_code: int, pattern_code: int,
...@@ -121,7 +122,7 @@ if flashinfer_comm is not None: ...@@ -121,7 +122,7 @@ if flashinfer_comm is not None:
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
) )
assert _FI_WORKSPACE_TENSOR is not None, ( assert _FI_WORKSPACE is not None, (
"Flashinfer must be enabled when using flashinfer" "Flashinfer must be enabled when using flashinfer"
) )
if norm_out is None: if norm_out is None:
...@@ -134,24 +135,18 @@ if flashinfer_comm is not None: ...@@ -134,24 +135,18 @@ if flashinfer_comm is not None:
residual_out = allreduce_in residual_out = allreduce_in
# For the sizes that are smaller than the max size, # For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce # we only use flashinfer one shot allreduce
flashinfer_comm.trtllm_allreduce_fusion( flashinfer_comm.allreduce_fusion(
allreduce_in=allreduce_in, input=allreduce_in,
token_num=allreduce_in.shape[0], workspace=_FI_WORKSPACE,
pattern=pattern_code,
residual_in=residual, residual_in=residual,
residual_out=residual_out, residual_out=residual_out,
norm_out=norm_out, norm_out=norm_out,
rms_gamma=rms_gamma, rms_gamma=rms_gamma,
rms_eps=rms_eps, rms_eps=rms_eps,
world_rank=world_rank,
world_size=world_size,
hidden_dim=allreduce_in.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
launch_with_pdl=launch_with_pdl, launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot, use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc, fp32_acc=fp32_acc,
pattern_code=pattern_code,
allreduce_out=None,
quant_out=quant_out, quant_out=quant_out,
scale_out=scale_out, scale_out=scale_out,
# in vllm we only support swizzled layout # in vllm we only support swizzled layout
...@@ -164,10 +159,8 @@ if flashinfer_comm is not None: ...@@ -164,10 +159,8 @@ if flashinfer_comm is not None:
residual: torch.Tensor, residual: torch.Tensor,
rms_gamma: torch.Tensor, rms_gamma: torch.Tensor,
rms_eps: float, rms_eps: float,
world_rank: int,
world_size: int, world_size: int,
launch_with_pdl: bool, launch_with_pdl: bool,
trigger_completion_at_end: bool,
fp32_acc: bool, fp32_acc: bool,
max_token_num: int, max_token_num: int,
pattern_code: int, pattern_code: int,
...@@ -200,25 +193,18 @@ class FlashInferFusedAllReduceParams: ...@@ -200,25 +193,18 @@ class FlashInferFusedAllReduceParams:
def __init__( def __init__(
self, self,
rank: int,
world_size: int, world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024, max_token_num: int = 1024,
) -> None: ) -> None:
self.rank = rank
self.world_size = world_size self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True self.launch_with_pdl = True
self.fp32_acc = True self.fp32_acc = True
self.max_token_num = max_token_num self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]: def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
return { return {
"world_rank": self.rank,
"world_size": self.world_size, "world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl, "launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc, "fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num, "max_token_num": self.max_token_num,
} }
...@@ -712,7 +698,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -712,7 +698,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.hidden_dim = config.model_config.get_hidden_size() self.hidden_dim = config.model_config.get_hidden_size()
self.group = get_tp_group().device_group self.group = get_tp_group().device_group
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
use_fp32_lamport = self.model_dtype == torch.float32
if flashinfer_comm is None: if flashinfer_comm is None:
logger.warning( logger.warning(
"Flashinfer is not installed or comm module not found, " "Flashinfer is not installed or comm module not found, "
...@@ -730,7 +715,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -730,7 +715,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.tp_size, self.tp_size,
) )
return return
element_size = 4 if use_fp32_lamport else 2 element_size = torch.tensor([], dtype=self.model_dtype).element_size()
self.max_token_num = max_size // (self.hidden_dim * element_size) self.max_token_num = max_size // (self.hidden_dim * element_size)
# take the min to save workspace size and we'll never use more # take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyways # than max_num_batched_tokens anyways
...@@ -744,23 +729,19 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -744,23 +729,19 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global", scope="global",
) )
self.ipc_handles, workspace_tensor = ( self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( backend="trtllm",
tp_rank=rank, world_size=self.tp_size,
tp_size=self.tp_size, rank=rank,
max_token_num=self.max_token_num, max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim, hidden_dim=self.hidden_dim,
group=self.group, dtype=self.model_dtype,
use_fp32_lamport=use_fp32_lamport,
)
) )
global _FI_WORKSPACE_TENSOR global _FI_WORKSPACE
_FI_WORKSPACE_TENSOR = workspace_tensor _FI_WORKSPACE = self.workspace
self.allreduce_params = FlashInferFusedAllReduceParams( self.allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=self.tp_size, world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
max_token_num=self.max_token_num, max_token_num=self.max_token_num,
) )
...@@ -832,7 +813,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -832,7 +813,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
def __del__(self) -> None: def __del__(self) -> None:
if getattr(self, "disabled", True): if getattr(self, "disabled", True):
return return
if flashinfer_comm is not None: if getattr(self, "workspace", None) is not None:
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( with contextlib.suppress(Exception):
self.ipc_handles, self.group self.workspace.destroy()
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment