Unverified Commit ae268b63 authored by Xin Li's avatar Xin Li Committed by GitHub
Browse files

Fix Flashinfer Allreduce+Norm enable disable calculation based on...


Fix Flashinfer Allreduce+Norm enable disable calculation based on `fi_allreduce_fusion_max_token_num` (#21325)
Signed-off-by: default avatarXIn Li <xinli@nvidia.com>
parent 35366ae5
...@@ -159,6 +159,9 @@ if flashinfer_comm is not None: ...@@ -159,6 +159,9 @@ if flashinfer_comm is not None:
6: MiB // 2, # 512KB 6: MiB // 2, # 512KB
8: MiB // 2, # 512KB 8: MiB // 2, # 512KB
} }
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
def call_trtllm_fused_allreduce_norm( def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor, allreduce_in: torch.Tensor,
...@@ -173,12 +176,16 @@ if flashinfer_comm is not None: ...@@ -173,12 +176,16 @@ if flashinfer_comm is not None:
max_token_num: int, max_token_num: int,
norm_out: Optional[torch.Tensor] = None, norm_out: Optional[torch.Tensor] = None,
) -> None: ) -> None:
use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
1] * allreduce_in.element_size() <= min( num_tokens, hidden_size = allreduce_in.shape
_FI_MAX_SIZES[world_size], element_size = allreduce_in.element_size()
max_token_num * allreduce_in.shape[0] * current_tensor_size = num_tokens * hidden_size * element_size
allreduce_in.element_size(), max_fusion_size = max_token_num * hidden_size * element_size
) use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
if use_flashinfer: if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer" ), "Flashinfer must be enabled when using flashinfer"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment