"tests/python/vscode:/vscode.git/clone" did not exist on "cf5c19302e94ddf274dd06bbd6d4ffee56a5d9ed"
Unverified Commit 9d834fdc authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

Revert "feat: update flashinfer ar oneshot params (#8687)" (#9054)

parent b3279251
...@@ -441,6 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn: ...@@ -441,6 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
and _is_flashinfer_available and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion") and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"] and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 128
): ):
hidden_states, residual = layernorm.forward_with_allreduce_fusion( hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual hidden_states, residual
......
...@@ -125,7 +125,7 @@ def flashinfer_allreduce_residual_rmsnorm( ...@@ -125,7 +125,7 @@ def flashinfer_allreduce_residual_rmsnorm(
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
max_token_num: int = 128, max_token_num: int = 128,
use_oneshot: Optional[bool] = None, use_oneshot: bool = True,
trigger_completion_at_end: bool = False, trigger_completion_at_end: bool = False,
fp32_acc: bool = False, fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment