"""Process-group helpers for data-parallel attention combined with MoE TP.

`initialize_dp_attention` records the attention TP/DP sizes and ranks and
creates a dedicated "moe_tp" process group; the remaining functions are thin
accessors and collective wrappers built on that state.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging

import torch

import vllm.envs as envs
from vllm.distributed.parallel_state import (GroupCoordinator,
                                             init_model_parallel_group,
                                             get_world_group)
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_reduce_scatter,
                              get_tp_group)

# Mirrors `parallel_config.enable_dp_attention` once initialization has run.
_ENABLE_DP_ATTENTION_FLAG: bool = False

# Process group used for MoE tensor parallelism; None until initialized.
_MOE_TP: Optional[GroupCoordinator] = None

# All of the values below use None as the "not initialized" sentinel so the
# getters' `is not None` assertions can actually fire. (0 cannot serve as a
# sentinel because it is a valid rank.)
_ATTN_DP_SIZE: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_DP_RANK: Optional[int] = None
# NOTE: "_MOT_" (sic) is the historical spelling of these MoE TP globals; it
# is kept unchanged in case other modules reference the names.
_MOT_TP_SIZE: Optional[int] = None
_MOT_TP_RANK: Optional[int] = None


def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Record attention DP/TP layout and create the MoE TP process group.

    Args:
        vllm_config: A `vllm.config.VllmConfig`; its `parallel_config` supplies
            the data/pipeline/tensor parallel sizes, the data-parallel rank
            and the `enable_dp_attention` flag.
        backend: Backend for the new process group. When falsy, the backend of
            the world group's device group is used.

    Raises:
        AssertionError: If `vllm_config` is not a `VllmConfig`, or if the MoE
            TP group has already been created.
    """
    from vllm.config import VllmConfig
    assert isinstance(vllm_config, VllmConfig)

    global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE
    global _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
    global _MOE_TP

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    # Build the moe tensor model-parallel groups.
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size

    # The MoE TP group covers every rank of one pipeline stage.
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % moe_tp_size

    assert _MOE_TP is None, (
        "moe tensor model parallel group is already initialized")
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # One group of `moe_tp_size` consecutive global ranks per pipeline stage.
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")


def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size recorded at initialization.

    Raises:
        AssertionError: If `initialize_dp_attention` has not run yet.
    """
    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
    return _ATTN_TP_SIZE


def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank recorded at initialization.

    Raises:
        AssertionError: If `initialize_dp_attention` has not run yet.
    """
    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
    return _ATTN_TP_RANK


def get_moe_tp_group() -> GroupCoordinator:
    """Return the "moe_tp" group created by `initialize_dp_attention`.

    Raises:
        AssertionError: If the group has not been created yet.
    """
    assert _MOE_TP is not None, (
        "tensor model parallel group is not initialized")
    return _MOE_TP


def get_attention_dp_size() -> int:
    """Return the attention data-parallel size recorded at initialization.

    Raises:
        AssertionError: If `initialize_dp_attention` has not run yet.
    """
    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _ATTN_DP_SIZE


def get_moe_tp_rank() -> int:
    """Return this process's rank within its MoE TP group.

    The value is `global_rank % moe_tp_size`, as computed at initialization.

    Raises:
        AssertionError: If `initialize_dp_attention` has not run yet.
    """
    assert _MOT_TP_RANK is not None, "dp attention not initialized!"
    return _MOT_TP_RANK


def get_moe_tp_size() -> int:
    """Return the MoE TP group size (`world_size // pipeline_parallel_size`).

    Raises:
        AssertionError: If `initialize_dp_attention` has not run yet.
    """
    assert _MOT_TP_SIZE is not None, "dp attention not initialized!"
    return _MOT_TP_SIZE


def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention TP (the regular TP group)."""
    return get_tp_group()


def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across the MoE TP group along `dim`."""
    return get_moe_tp_group().all_gather(input_, dim)


def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter the input tensor across the MoE TP group along `dim`."""
    return get_moe_tp_group().reduce_scatter(input_, dim)


def dp_gather(hidden_states: torch.Tensor, ) -> torch.Tensor:
    """Gather hidden states along dim 0 across the MoE TP group.

    When the attention TP size is greater than 1, the tensor is first
    reduce-scattered along dim 0 with `tensor_model_parallel_reduce_scatter`
    (presumably over the attention TP group -- confirm against vLLM) and only
    then all-gathered across the MoE TP group.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)


def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter hidden states along dim 0 across the MoE TP group.

    If the MoE TP group size differs from the attention DP size, the result is
    additionally all-gathered along dim 0 with
    `tensor_model_parallel_all_gather`.
    """
    # Evaluate the condition before issuing any collective so that an
    # uninitialized state is reported up front.
    needs_tp_gather = (
        get_moe_tp_group().world_size != get_attention_dp_size())

    hidden_states = moe_tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    if needs_tp_gather:
        hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
    return hidden_states