Unverified Commit 81d954f4 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[WideEP] Remove naive all2all. Use allgather_reducescatter instead (#33728)


Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
parent 47fcb8ca
......@@ -38,115 +38,6 @@ if has_flashinfer_nvlink_one_sided():
logger = init_logger(__name__)
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
efficient at all. The main purpose is for testing and
debugging.
"""
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
x: torch.Tensor,
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool,
) -> torch.Tensor:
assert len(x.shape) == 2
buffer = torch.empty(
(cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
)
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = self.world_size if is_sequence_parallel else self.dp_world_size
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x)
for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_sp_cpu[idx]
get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states
def destroy(self):
pass
class AgRsAll2AllManager(All2AllManagerBase):
"""
An implementation of all2all communication based on
......
......@@ -49,18 +49,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
self.supports_tensor_dict = isinstance(self.dist_module, _CPUSHMDistributed)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
if self.all2all_backend not in (
"naive",
"allgather_reducescatter",
): # type: ignore[has-type]
logger.warning(
"`%s` all2all manager is not supported on CPU. "
"Falling back to `naive` all2all manager for CPU.",
"Falling back to `allgather_reducescatter` manager.",
self.all2all_backend, # type: ignore[has-type]
)
self.all2all_backend = "naive"
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
from .all2all import AgRsAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using allgather_reducescatter all2all manager.")
def _all_group_ranks_share_shm_group_name(self) -> bool:
"""
......
......@@ -115,13 +115,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
if self.use_all2all:
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
if self.all2all_backend in ("naive", "allgather_reducescatter"):
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(
......
......@@ -23,13 +23,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif self.all2all_backend == "allgather_reducescatter":
if self.all2all_backend in ("naive", "allgather_reducescatter"):
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment