import os
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from . import deep_ep_cpp
from .deep_ep_cpp import Config, EventHandle
from .utils import EventOverlap, check_nvlink_connections


class Buffer:
    """
    The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
        - high-throughput intranode all-to-all (dispatch and combine, using NVLink)
        - high-throughput internode all-to-all (dispatch and combine, using RDMA and NVLink)
        - low-latency all-to-all (dispatch and combine, using RDMA)

    Attributes:
        num_sms: the SMs used in high-throughput kernels (class-wide, change it with `set_num_sms`).
        rank: the local rank number.
        group_size: the number of ranks in the group.
        group: the communication group.
        num_nvl_bytes: the buffer size for intranode NVLink communication.
        num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
        low_latency_mode: whether the buffer was created in low-latency mode.
        explicitly_destroy: whether `destroy()` must be called explicitly to release resources.
        enable_shrink: whether shrink mode (dynamic rank masking) is enabled.
        runtime: the C++ runtime (set to `None` after `destroy()`).
    """

    num_sms: int = 24

    def __init__(
        self,
        group: dist.ProcessGroup,
        num_nvl_bytes: int = 0,
        num_rdma_bytes: int = 0,
        low_latency_mode: bool = False,
        num_qps_per_rank: int = 24,
        allow_nvlink_for_low_latency_mode: bool = True,
        allow_mnnvl: bool = False,
        explicitly_destroy: bool = False,
        enable_shrink: bool = False,
    ) -> None:
        """
        Initialize the communication buffer.

        Arguments:
            group: the communication group.
            num_nvl_bytes: the buffer size for intranode NVLink communication.
            num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
            low_latency_mode: whether to enable low-latency mode.
            num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
                to the number of local experts.
            allow_nvlink_for_low_latency_mode: whether to allow NVLink traffic for low-latency mode; note that this
                is somewhat incompatible with the hook-based overlapping.
                Warning: PCIe connections may lead to errors due to memory ordering issues,
                please make sure all connections are via NVLink.
            allow_mnnvl: whether to allow multi-node NVLink (MNNVL); when not set, MNNVL detection is disabled.
            explicitly_destroy: if set to True, you need to explicitly call `destroy()` to release resources;
                otherwise, the resources will be released by the destructor.
                Note: releasing resources in the destructor may cause Python's exception handling process to hang.
            enable_shrink: whether to enable shrink mode. This mode allocates a mask buffer to support masking
                ranks dynamically.
        """
        check_nvlink_connections(group)

        # Initialize the CPP runtime
        self.rank = group.rank()
        self.group_size = group.size()
        self.group = group
        self.num_nvl_bytes = num_nvl_bytes
        self.num_rdma_bytes = num_rdma_bytes
        self.low_latency_mode = low_latency_mode
        self.explicitly_destroy = explicitly_destroy
        self.enable_shrink = enable_shrink
        self.runtime = deep_ep_cpp.Buffer(
            self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy, enable_shrink
        )

        # Synchronize device IDs
        device_ids = [None] * self.group_size
        local_device_id = self.runtime.get_local_device_id()
        dist.all_gather_object(device_ids, local_device_id, group)

        # Synchronize IPC handles
        ipc_handles = [None] * self.group_size
        local_ipc_handle = self.runtime.get_local_ipc_handle()
        dist.all_gather_object(ipc_handles, local_ipc_handle, group)

        # Synchronize DUSHMEM unique IDs
        root_unique_id = None
        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
            # Pin this device to its HCA when DEEP_EP_DEVICE_TO_HCA_MAPPING is set
            self._setup_device_hca_mapping()

            assert num_qps_per_rank > 0
            os.environ["DUSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1"
            # Force the IBRC transport (IBGDA disabled).
            # NOTE(review): the low-latency docstrings still say IBGDA must be enabled — confirm the intended transport.
            os.environ["DUSHMEM_IB_ENABLE_IBGDA"] = "0"
            os.environ["DUSHMEM_IBGDA_NIC_HANDLER"] = "gpu"
            os.environ["DUSHMEM_IB_DISABLE_DMABUF"] = "1"
            os.environ["DUSHMEM_ENABLE_NIC_PE_MAPPING"] = "1"
            os.environ["DUSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}"
            # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
            os.environ["DUSHMEM_QP_DEPTH"] = os.environ.get("DUSHMEM_QP_DEPTH", "1024")

            # Reduce gpu memory usage
            # 6 default teams + 1 extra team
            os.environ["DUSHMEM_MAX_TEAMS"] = "7"
            # Disable NVLink SHArP
            os.environ["DUSHMEM_DISABLE_NVLS"] = "1"
            # NOTES: DUSHMEM initialization requires at least 256 MiB; the granularity used here is 2 ** 29 B (512 MiB)
            os.environ["DUSHMEM_CUMEM_GRANULARITY"] = f"{2 ** 29}"

            if not allow_mnnvl:
                # Disable multi-node NVLink detection
                os.environ["DUSHMEM_DISABLE_MNNVL"] = "1"

            # Synchronize using the root ID: only the root rank contributes a non-None unique ID
            dushmem_unique_ids = [None] * self.group_size
            if (low_latency_mode and self.rank == 0) or (
                not low_latency_mode and self.runtime.get_rdma_rank() == 0
            ):
                root_unique_id = self.runtime.get_local_dushmem_unique_id()
            dist.all_gather_object(dushmem_unique_ids, root_unique_id, group)
            root_unique_id = dushmem_unique_ids[
                0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)
            ]

        # Make CPP runtime available
        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
        assert self.runtime.is_available()

    def _setup_device_hca_mapping(self):
        """
        Set up the device-to-NIC mapping from the DEEP_EP_DEVICE_TO_HCA_MAPPING environment variable.

        The mapping format is "0:mlx5_0:1,1:mlx5_1:1,...", where each comma-separated entry maps a device ID to
        an HCA name, separated by the first colon. The HCA name may itself contain colons (e.g. the ":1" port
        suffix). Whitespace around entries and empty entries (e.g. from a trailing comma) are ignored.

        When the variable is set, `DUSHMEM_ENABLE_PE_MAPPING` is set to "1" and `DUSHMEM_HCA_LIST` to the HCA
        mapped to `torch.cuda.current_device()`. When it is unset, nothing happens.

        Raises:
            AssertionError: if an entry has no colon, or the current device is not in the mapping.
            ValueError: if a device ID is not an integer.
        """
        mapping_str = os.environ.get('DEEP_EP_DEVICE_TO_HCA_MAPPING')
        if mapping_str is None:
            return

        device_mapping = {}
        for mapping in mapping_str.split(','):
            mapping = mapping.strip()
            if not mapping:
                # Tolerate trailing / duplicated commas
                continue
            assert ':' in mapping, (
                f"Invalid mapping format '{mapping}' in DEEP_EP_DEVICE_TO_HCA_MAPPING. "
                f"Expected format: <device_id>:<hca_name>"
            )
            # Split only on the first colon; keep the rest (including a ":1" port suffix) as the HCA name
            device_id, _, hca_name = mapping.partition(':')
            device_mapping[int(device_id)] = hca_name

        # NOTE: the logical device index is used as-is (CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES is not translated)
        current_device = torch.cuda.current_device()
        assert current_device in device_mapping, f"Current HIP device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
        os.environ['DUSHMEM_ENABLE_PE_MAPPING'] = '1'
        os.environ['DUSHMEM_HCA_LIST'] = device_mapping[current_device]

    def destroy(self):
        """
        Destroy the cpp runtime and release resources.

        Only allowed when the buffer was created with `explicitly_destroy=True`; afterwards `self.runtime` is `None`.
        """
        assert self.explicitly_destroy, "`explicitly_destroy` flag must be set"

        self.runtime.destroy()
        self.runtime = None

    @staticmethod
    def set_num_sms(new_num_sms: int) -> None:
        """
        Set the number of SMs to use in high-throughput kernels (shared by all `Buffer` instances).

        Arguments:
            new_num_sms: the new number to be set, must be even.
        """
        assert new_num_sms % 2 == 0, "The SM count must be even"
        Buffer.num_sms = new_num_sms

    @staticmethod
    def capture() -> EventOverlap:
        """
        Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`.

        Returns:
            event: the captured event.
        """
        return EventOverlap(EventHandle())

    @staticmethod
    def get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int
    ) -> int:
        """
        Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.

        Arguments:
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            hidden: the hidden dimension of each token.
            num_ranks: the number of EP group ranks.
            num_experts: the number of all experts.

        Returns:
            size: the RDMA buffer size recommended.
        """
        return deep_ep_cpp.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts
        )

    def get_comm_stream(self) -> torch.Stream:
        """
        Get the communication stream.

        Returns:
            stream: the communication stream, wrapped as a `torch.cuda.Stream`.
        """
        ts: torch.Stream = self.runtime.get_comm_stream()
        return torch.cuda.Stream(
            stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type
        )

    def get_local_buffer_tensor(
        self,
        dtype: torch.dtype,
        size: Optional[torch.Size] = None,
        offset: int = 0,
        use_rdma_buffer: bool = False,
    ) -> torch.Tensor:
        """
        Get the raw buffer (slice supported) as a PyTorch tensor.

        Arguments:
            dtype: the data type (PyTorch `dtype`) for the tensor.
            size: the slice size (by elements) to get from the buffer; `None` returns the whole buffer view.
            offset: the offset of the beginning element.
            use_rdma_buffer: whether to return the RDMA buffer.

        Returns:
            tensor: the buffer view, reshaped to `size` when it is given.
        """
        tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
        if size is None:
            return tensor

        assert tensor.numel() >= size.numel()
        return tensor[: size.numel()].view(size)

    @staticmethod
    def _unpack_bias(
        bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Normalize the `bias` argument of the combine functions into a `(bias_0, bias_1)` pair.

        A single tensor becomes `(bias, None)`, a pair (tuple or list) is unpacked, and `None` gives `(None, None)`.
        """
        bias_0, bias_1 = None, None
        if isinstance(bias, torch.Tensor):
            bias_0 = bias
        elif isinstance(bias, (tuple, list)):
            assert len(bias) == 2
            bias_0, bias_1 = bias
        else:
            # Previously any other type was silently dropped, which hides caller mistakes
            assert bias is None, f"Unsupported bias type: {type(bias)}"
        return bias_0, bias_1

    @staticmethod
    def get_dispatch_config(num_ranks: int) -> Config:
        """
        Get a recommended dispatch config.

        Arguments:
            num_ranks: the number of ranks.

        Returns:
            config: the recommended config.

        Raises:
            AssertionError: if there is no tuned config for `num_ranks`.
        """

        # TODO: automatically tune
        config_map = {
            2: Config(Buffer.num_sms, 24, 256, 6, 128),
            4: Config(Buffer.num_sms, 6, 256, 6, 128),
            8: Config(Buffer.num_sms, 6, 256, 6, 128),
            # 16: Config(Buffer.num_sms, 36, 288, 20, 128),
            16: Config(Buffer.num_sms, 8, 512, 16, 128),
            24: Config(Buffer.num_sms, 8, 288, 32, 128),
            32: Config(Buffer.num_sms, 32, 288, 32, 128),
            64: Config(Buffer.num_sms, 20, 288, 28, 128),
            128: Config(Buffer.num_sms, 20, 560, 32, 128),
            144: Config(Buffer.num_sms, 32, 720, 12, 128),
            160: Config(Buffer.num_sms, 28, 720, 12, 128),
        }
        assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}"
        return config_map[num_ranks]

    @staticmethod
    def get_combine_config(num_ranks: int) -> Config:
        """
        Get a recommended combine config.

        Arguments:
            num_ranks: the number of ranks.

        Returns:
            config: the recommended config.

        Raises:
            AssertionError: if there is no tuned config for `num_ranks`.
        """

        # TODO: automatically tune
        config_map = {
            2: Config(Buffer.num_sms, 10, 256, 6, 128),
            4: Config(Buffer.num_sms, 9, 256, 6, 128),
            8: Config(Buffer.num_sms, 4, 256, 6, 128),
            # 16: Config(Buffer.num_sms, 4, 288, 12, 128),
            16: Config(Buffer.num_sms, 8, 512, 16, 128),
            24: Config(Buffer.num_sms, 1, 288, 8, 128),
            32: Config(Buffer.num_sms, 1, 288, 8, 128),
            64: Config(Buffer.num_sms, 1, 288, 20, 128),
            128: Config(Buffer.num_sms, 1, 560, 12, 128),
            144: Config(Buffer.num_sms, 2, 720, 8, 128),
            160: Config(Buffer.num_sms, 2, 720, 8, 128),
        }
        assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}"
        return config_map[num_ranks]

    # noinspection PyTypeChecker
    def get_dispatch_layout(
        self,
        topk_idx: torch.Tensor,
        num_experts: int,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
        """
        Calculate the layout required for later communication.

        Arguments:
            topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
                `-1` means no selections.
            num_experts: the number of experts.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.

        Returns:
            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
                rank (with the same GPU index), return `None` for intranode settings.
            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = (
            self.runtime.get_dispatch_layout(
                topk_idx,
                num_experts,
                getattr(previous_event, "event", None),
                async_finish,
                allocate_on_comm_stream,
            )
        )
        return (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            EventOverlap(event),
        )

    # noinspection PyTypeChecker
    def dispatch(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        handle: Optional[Tuple] = None,
        num_tokens_per_rank: Optional[torch.Tensor] = None,
        num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
        is_token_in_rank: Optional[torch.Tensor] = None,
        num_tokens_per_expert: Optional[torch.Tensor] = None,
        topk_idx: Optional[torch.Tensor] = None,
        topk_weights: Optional[torch.Tensor] = None,
        expert_alignment: int = 1,
        num_worst_tokens: int = 0,
        config: Optional[Config] = None,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
        num_recv_tokens_per_expert_as_cuda: bool = False,
    ) -> Tuple[
        Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[Union[List[int], torch.Tensor]],
        Optional[Tuple],
        EventOverlap,
    ]:
        """
        Dispatch tokens to different ranks, both intranode and internode settings are supported.
        Intranode kernels require all the ranks should be visible via NVLink.
        Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
            index should be visible via RDMA.

        Arguments:
            x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
                and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as
                `[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]`
                 (requiring divisible) with type `torch.float`.
            handle: an optional communication handle, if set, the CPU will reuse the layout information to save some time
                (cached mode).
            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
                rank (with the same GPU index), return `None` for intranode settings.
            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
            topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
                `-1` means no selections.
            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
            expert_alignment: align the number of tokens received by each local expert to this variable.
            num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
                will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
            config: the performance tuning config.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
            num_recv_tokens_per_expert_as_cuda: return `num_recv_tokens_per_expert` as a CUDA tensor instead of a
                Python list. Only honored for intranode settings; internode dispatch always returns a list.

        Returns:
            recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the
                received token count.
            recv_topk_idx: received expert indices (`None` in cached mode).
            recv_topk_weights: received expert weights (`None` in cached mode).
            num_recv_tokens_per_expert: Python list or CUDA tensor shaped `[num_local_experts]`, the received token count
                by each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified,
                the list will be empty. `None` in cached mode.
            handle: the returned communication handle (`None` in cached mode).
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        # Default config
        config = self.get_dispatch_config(self.group_size) if config is None else config

        # Internode
        if self.runtime.get_num_rdma_ranks() > 1:
            assert num_worst_tokens == 0, "Internode dispatch does not support `num_worst_tokens > 0`"
            return self.internode_dispatch(
                x,
                handle,
                num_tokens_per_rank,
                num_tokens_per_rdma_rank,
                is_token_in_rank,
                num_tokens_per_expert,
                topk_idx,
                topk_weights,
                expert_alignment,
                config,
                previous_event,
                async_finish,
                allocate_on_comm_stream,
            )

        # Launch the kernel with cached or non-cached mode
        x, x_scales = x if isinstance(x, tuple) else (x, None)
        if handle is not None:
            # Cached mode: reuse the layout stored in the handle
            assert topk_idx is None and topk_weights is None
            (
                rank_prefix_matrix,
                channel_prefix_matrix,
                recv_channel_prefix_matrix,
                recv_src_idx,
                is_token_in_rank,
                send_head,
            ) = handle
            num_recv_tokens = recv_src_idx.size(0)
            recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                x,
                x_scales,
                None,
                None,
                None,
                is_token_in_rank,
                None,
                num_recv_tokens,
                rank_prefix_matrix,
                channel_prefix_matrix,
                expert_alignment,
                num_worst_tokens,
                config,
                getattr(previous_event, "event", None),
                async_finish,
                allocate_on_comm_stream,
            )
            return (
                (recv_x, recv_x_scales) if x_scales is not None else recv_x,
                None,
                None,
                None,
                None,
                EventOverlap(event),
            )

        # Non-cached mode: the layout must be provided (see `get_dispatch_layout`)
        assert (
            num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
        )
        (
            recv_x,
            recv_x_scales,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            num_recv_tokens_per_expert_cuda,
            rank_prefix_matrix,
            channel_prefix_matrix,
            recv_channel_prefix_matrix,
            recv_src_idx,
            send_head,
            event,
        ) = self.runtime.intranode_dispatch(
            x,
            x_scales,
            topk_idx,
            topk_weights,
            num_tokens_per_rank,
            is_token_in_rank,
            num_tokens_per_expert,
            0,
            None,
            None,
            expert_alignment,
            num_worst_tokens,
            config,
            getattr(previous_event, "event", None),
            async_finish,
            allocate_on_comm_stream,
        )
        handle = (
            rank_prefix_matrix,
            channel_prefix_matrix,
            recv_channel_prefix_matrix,
            recv_src_idx,
            is_token_in_rank,
            send_head,
        )
        return (
            (recv_x, recv_x_scales) if x_scales is not None else recv_x,
            recv_topk_idx,
            recv_topk_weights,
            (
                num_recv_tokens_per_expert_cuda
                if num_recv_tokens_per_expert_as_cuda
                else num_recv_tokens_per_expert_list
            ),
            handle,
            EventOverlap(event),
        )

    # noinspection PyTypeChecker
    def combine(
        self,
        x: torch.Tensor,
        handle: Tuple,
        topk_weights: Optional[torch.Tensor] = None,
        bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        config: Optional[Config] = None,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
        """
        Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode
            settings are supported.
        Intranode kernels require all the ranks should be visible via NVLink.
        Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
            index should be visible via RDMA.

        Arguments:
            x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.
            handle: a must-set communication handle, you can obtain this from the dispatch function.
            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks.
            bias: an optional tensor, or a pair of tensors, passed to the kernel as `bias_0` / `bias_1`
                (presumably added to the combined output — confirm against the kernel).
            config: the performance tuning config.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.

        Returns:
            recv_x: the reduced token from its dispatched ranks.
            recv_topk_weights: the reduced top-k weights from its dispatch ranks.
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        # Default config
        config = self.get_combine_config(self.group_size) if config is None else config

        # Internode
        if self.runtime.get_num_rdma_ranks() > 1:
            return self.internode_combine(
                x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream
            )

        # NOTES: the second `_` is for the sending side, so we should use the third one
        rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
        bias_0, bias_1 = Buffer._unpack_bias(bias)

        # Launch the kernel
        recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
            x,
            topk_weights,
            bias_0,
            bias_1,
            src_idx,
            rank_prefix_matrix,
            channel_prefix_matrix,
            send_head,
            config,
            getattr(previous_event, "event", None),
            async_finish,
            allocate_on_comm_stream,
        )
        return recv_x, recv_topk_weights, EventOverlap(event)

    # noinspection PyTypeChecker
    def internode_dispatch(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        handle: Optional[Tuple] = None,
        num_tokens_per_rank: Optional[torch.Tensor] = None,
        num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
        is_token_in_rank: Optional[torch.Tensor] = None,
        num_tokens_per_expert: Optional[torch.Tensor] = None,
        topk_idx: Optional[torch.Tensor] = None,
        topk_weights: Optional[torch.Tensor] = None,
        expert_alignment: int = 1,
        config: Optional[Config] = None,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> Tuple[
        Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[List[int]],
        Optional[Tuple],
        EventOverlap,
    ]:
        """
        Internode dispatch implementation, for more details, please refer to the `dispatch` docs.
        Normally, you should not directly call this function.

        In cached mode (`handle` given), the top-k outputs, the per-expert counts and the handle are returned as `None`.
        """
        assert config is not None

        # Launch the kernel with cached or non-cached mode
        x, x_scales = x if isinstance(x, tuple) else (x, None)
        if handle is not None:
            assert topk_idx is None and topk_weights is None
            (
                is_token_in_rank,
                rdma_channel_prefix_matrix,
                gbl_channel_prefix_matrix,
                recv_rdma_channel_prefix_matrix,
                recv_rdma_rank_prefix_sum,
                recv_gbl_channel_prefix_matrix,
                recv_gbl_rank_prefix_sum,
                recv_src_meta,
                send_rdma_head,
                send_nvl_head,
            ) = handle
            num_recv_tokens = recv_src_meta.size(0)
            num_rdma_recv_tokens = send_nvl_head.size(0)
            recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
                x,
                x_scales,
                topk_idx,
                topk_weights,
                None,
                None,
                is_token_in_rank,
                None,
                num_recv_tokens,
                num_rdma_recv_tokens,
                rdma_channel_prefix_matrix,
                recv_rdma_rank_prefix_sum,
                gbl_channel_prefix_matrix,
                recv_gbl_rank_prefix_sum,
                expert_alignment,
                config,
                getattr(previous_event, "event", None),
                async_finish,
                allocate_on_comm_stream,
            )
            return (
                (recv_x, recv_x_scales) if x_scales is not None else recv_x,
                None,
                None,
                None,
                None,
                EventOverlap(event),
            )

        assert (
            num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
        )
        (
            recv_x,
            recv_x_scales,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            rdma_channel_prefix_matrix,
            gbl_channel_prefix_matrix,
            recv_rdma_channel_prefix_matrix,
            recv_rdma_rank_prefix_sum,
            recv_gbl_channel_prefix_matrix,
            recv_gbl_rank_prefix_sum,
            recv_src_meta,
            send_rdma_head,
            send_nvl_head,
            event,
        ) = self.runtime.internode_dispatch(
            x,
            x_scales,
            topk_idx,
            topk_weights,
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            is_token_in_rank,
            num_tokens_per_expert,
            0,
            0,
            None,
            None,
            None,
            None,
            expert_alignment,
            config,
            getattr(previous_event, "event", None),
            async_finish,
            allocate_on_comm_stream,
        )
        handle = (
            is_token_in_rank,
            rdma_channel_prefix_matrix,
            gbl_channel_prefix_matrix,
            recv_rdma_channel_prefix_matrix,
            recv_rdma_rank_prefix_sum,
            recv_gbl_channel_prefix_matrix,
            recv_gbl_rank_prefix_sum,
            recv_src_meta,
            send_rdma_head,
            send_nvl_head,
        )
        return (
            (recv_x, recv_x_scales) if x_scales is not None else recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            EventOverlap(event),
        )

    # noinspection PyTypeChecker
    def internode_combine(
        self,
        x: torch.Tensor,
        handle: Union[tuple, list],
        topk_weights: Optional[torch.Tensor] = None,
        bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        config: Optional[Config] = None,
        previous_event: Optional[EventOverlap] = None,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
        """
        Internode combine implementation, for more details, please refer to the `combine` docs.
        Normally, you should not directly call this function.
        """
        assert config is not None

        # Unpack handle and bias: combine uses the receiver-side prefix matrices produced by dispatch
        (
            is_combined_token_in_rank,
            _,
            _,
            rdma_channel_prefix_matrix,
            rdma_rank_prefix_sum,
            gbl_channel_prefix_matrix,
            gbl_rank_prefix_sum,
            src_meta,
            send_rdma_head,
            send_nvl_head,
        ) = handle
        bias_0, bias_1 = Buffer._unpack_bias(bias)

        # Launch the kernel
        combined_x, combined_topk_weights, event = self.runtime.internode_combine(
            x,
            topk_weights,
            bias_0,
            bias_1,
            src_meta,
            is_combined_token_in_rank,
            rdma_channel_prefix_matrix,
            rdma_rank_prefix_sum,
            gbl_channel_prefix_matrix,
            send_rdma_head,
            send_nvl_head,
            config,
            getattr(previous_event, "event", None),
            async_finish,
            allocate_on_comm_stream,
        )
        return combined_x, combined_topk_weights, EventOverlap(event)

    def clean_low_latency_buffer(
        self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int
    ) -> None:
        """
        As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
            if the buffer is dirty at some time.
        For example, after running the normal dispatch/combine, you must run this function before executing any
            low-latency kernel.

        Arguments:
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            hidden: the hidden dimension of each token.
            num_experts: the number of all experts.
        """
        self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)

    # noinspection PyTypeChecker
    def low_latency_dispatch(
        self,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        num_max_dispatch_tokens_per_rank: int,
        num_experts: int,
        quant_type: int = 1,
        quant_group_size: int = 0,
        fp8_round_scale: bool = False,
        async_finish: bool = False,
        return_recv_hook: bool = False,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        torch.Tensor,
        Tuple,
        EventOverlap,
        Callable,
    ]:
        """
        A low-latency implementation for dispatching with IBGDA.
        This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
            (specifically, IBGDA must be enabled).
            NOTE(review): `__init__` currently sets `DUSHMEM_IB_ENABLE_IBGDA=0` (IBRC); confirm this requirement.
        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
            low-latency kernels' result tensors at a single moment.

        Arguments:
            x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
                supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
            topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as
                `[num_tokens, num_topk]`, only several top-k shapes are supported. `-1` indices (not selecting any
                expert) are supported.
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            num_experts: the number of all experts.

            Quantization options:
            quant_type: the quantization type enum (default `1`, INT8):
                0 -> None:      no quantization, keep the original precision.
                1 -> Int8:      symmetric INT8 quantization.
                2 -> FP8_E4M3:  FP8 E4M3 format (`__HIP_E4M3_FNUZ`).
                3 -> FP8_UE8M0: the UE8M0 scale format proposed with DeepSeek-V3.1 (only supports `round_scale=True`).
                4 -> FP8_E5M2:  FP8 E5M2 format (`__HIP_E5M2_FNUZ`).
            quant_group_size: the quantization group size:
                0   -> per-token quantization (per-channel).
                128 -> one scale per group of 128 elements (per-group).
            fp8_round_scale: whether to round the FP8 scale factors to powers of 2:
                True  -> scale = 2^k, zero hardware overhead.
                False -> scale = arbitrary float, higher precision.

            Asynchronous options:
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                but **without actually receiving the data**. You must call the received hook to make sure the data's
                arrival. If you do not set this flag, the kernel will ensure the data's arrival.

        Returns:
            recv_x: a tuple `(packed_recv_x, packed_recv_x_scales)` when `quant_type > 0`, otherwise the single tensor
                `packed_recv_x`.
                - packed_recv_x: the received token data, shaped
                    `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`.
                    Its dtype depends on `quant_type`:
                        quant_type == 1 -> torch.int8
                        quant_type == 2 -> torch.float8_e4m3fnuz
                        quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 data is stored in E4M3 format)
                        quant_type == 4 -> torch.float8_e5m2fnuz
                        otherwise (no quantization) -> torch.bfloat16
                - packed_recv_x_scales (optional): only present when `quant_type > 0`, stores the quantization scales,
                    shaped `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`.
                    - When quant_type == 3 (UE8M0): scales_col_size = hidden // 512, dtype `torch.int`
                        (four 8-bit scales packed per int32, one per 128-element group).
                        Note: this mode requires `fp8_round_scale=True` and a group size of 128.
                    - When quant_type == 1, 2, 4: scales_col_size = hidden // 128 (per-group) or 1 (per-channel),
                        dtype `torch.float32`.
                Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks`
                are, as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
            recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
                expert receives. As mentioned before, not all tokens are valid in `recv_x`.
            handle: the communication handle to be used in the `low_latency_combine` function.
            event: the event after executing the kernel (valid only if `async_finish` is set); when `async_finish`
                is set it also records the input/output tensors.
            hook: the receiving hook function (valid only if `return_recv_hook` is set).
        """
        (
            packed_recv_x,
            packed_recv_x_scales,
            packed_recv_count,
            packed_recv_src_info,
            packed_recv_layout_range,
            event,
            hook,
        ) = self.runtime.low_latency_dispatch(
            x,
            topk_idx,
            num_max_dispatch_tokens_per_rank,
            num_experts,
            quant_type,
            quant_group_size,
            fp8_round_scale,
            async_finish,
            return_recv_hook,
        )
        handle = (
            packed_recv_src_info,
            packed_recv_layout_range,
            num_max_dispatch_tokens_per_rank,
            x.size(1),
            num_experts,
        )
        tensors_to_record = (
            x,
            topk_idx,
            packed_recv_x,
            packed_recv_x_scales,
            packed_recv_count,
            packed_recv_src_info,
            packed_recv_layout_range,
        )
        recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
        return (
            recv_x,
            packed_recv_count,
            handle,
            EventOverlap(event, tensors_to_record if async_finish else None),
            hook,
        )

    # noinspection PyTypeChecker
    def low_latency_combine(
        self,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        handle: tuple,
        zero_copy: bool = False,
        async_finish: bool = False,
        return_recv_hook: bool = False,
        out: Optional[torch.Tensor] = None,
        combine_wait_recv_cost_stats: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, EventOverlap, Callable]:
        """
        A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
        This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
            (specifically, IBGDA must be enabled).
            Even for ranks in the same node, NVLink are fully disabled for simplicity.
            NOTE(review): `__init__` disables IBGDA and lets `allow_nvlink_for_low_latency_mode` enable P2P,
            so these two statements may be outdated — confirm.
        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
            low-latency kernels' result tensor at a single moment.

        Arguments:
            x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
                the local calculated tokens to be sent to this original rank and reduced.
            topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
                tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals
                to the number of dispatched tokens.
            topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
                tokens. The received tokens will be reduced with the weights in this tensor.
            handle: the communication handle given by the `dispatch` function.
            zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
                with `get_next_low_latency_combine_buffer`.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                but **without actually receiving the data**. You must call the received hook to make sure the data's
                arrival. If you do not set this flag, the kernel will ensure the data's arrival.
            out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
            combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
                which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
                This is useful for detecting and precisely localizing slow anomalies.

        Returns:
            combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
            event: the event after executing the kernel (valid only if `async_finish` is set); when `async_finish`
                is set it also records the input/output tensors.
            hook: the receiving hook function (valid only if `return_recv_hook` is set).
        """
        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
        combined_x, event, hook = self.runtime.low_latency_combine(
            x,
            topk_idx,
            topk_weights,
            src_info,
            layout_range,
            combine_wait_recv_cost_stats,
            num_max_dispatch_tokens_per_rank,
            num_experts,
            zero_copy,
            async_finish,
            return_recv_hook,
            out,
        )
        tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
        return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook

    def get_next_low_latency_combine_buffer(self, handle: object):
        """
        Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying.

        Arguments:
            handle: the communication handle given by the `dispatch` function.

        Returns:
            buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
                `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, you should fill this buffer
                by yourself.
        """
        _, _, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
        return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)