dp_attention.py 4.91 KB
Newer Older
王敏's avatar
王敏 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging

import torch

import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
                              get_tensor_model_parallel_rank,
                              tensor_model_parallel_reduce_scatter,
                              get_tp_group)

# Value of `parallel_config.enable_dp_attention`, recorded by
# initialize_dp_attention().
_ENABLE_DP_ATTENTION_FLAG: bool = False

# MoE tensor-parallel process group; stays None until
# initialize_dp_attention() creates it.
_MOE_TP: Optional[GroupCoordinator] = None

# Sizes and ranks recorded by initialize_dp_attention().
# None is the "not initialized" sentinel: the get_*() accessors in this module
# assert `is not None`, so a default of 0 would slip past those checks (and 0
# is a legitimate rank), letting callers read a bogus value before
# initialization.
_ATTN_DP_SIZE: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_DP_RANK: Optional[int] = None
# NOTE: the "_MOT_" spelling is kept as-is because other functions in this
# module refer to these exact names.
_MOT_TP_SIZE: Optional[int] = None
_MOT_TP_RANK: Optional[int] = None


def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Record DP-attention parallel sizes/ranks and create the MoE TP group.

    Reads the parallel configuration from ``vllm_config``, stores the derived
    attention and MoE sizes/ranks in this module's globals, and creates one
    "moe_tp" process group per pipeline stage. Each group spans
    ``world_size // pipeline_parallel_size`` consecutive global ranks.

    Args:
        vllm_config: A ``vllm.config.VllmConfig`` instance.
        backend: Backend handed to ``init_model_parallel_group``. When falsy,
            the backend of the world group's device group is used.

    Raises:
        AssertionError: If ``vllm_config`` is not a ``VllmConfig`` or the MoE
            tensor-parallel group has already been created.
    """
    from vllm.config import VllmConfig

    global _ENABLE_DP_ATTENTION_FLAG, _MOE_TP
    global _ATTN_DP_SIZE, _ATTN_TP_SIZE, _ATTN_TP_RANK, _ATTN_DP_RANK
    global _MOT_TP_SIZE, _MOT_TP_RANK

    assert isinstance(vllm_config, VllmConfig)
    # Check for double initialization before touching any module state so a
    # repeated call fails without overwriting previously recorded values.
    assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")

    parallel_config = vllm_config.parallel_config
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    # All ranks of one pipeline stage form a single MoE tensor-parallel group.
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention
    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % moe_tp_size

    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Build the moe tensor model-parallel groups: stage i owns global ranks
    # [i * moe_tp_size, (i + 1) * moe_tp_size).
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")

def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size held in ``_ATTN_TP_SIZE``.

    Asserts that the stored value is not ``None``.
    """
    tp_size = _ATTN_TP_SIZE
    assert tp_size is not None, "dp attention not initialized!"
    return tp_size

def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank held in ``_ATTN_TP_RANK``.

    Asserts that the stored value is not ``None``.
    """
    tp_rank = _ATTN_TP_RANK
    assert tp_rank is not None, "dp attention not initialized!"
    return tp_rank

def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel group held in ``_MOE_TP``.

    Asserts that the group has been created (i.e. is not ``None``).
    """
    moe_group = _MOE_TP
    assert moe_group is not None, ("tensor model parallel group is not initialized")
    return moe_group

def get_attention_dp_size() -> int:
    """Return the attention data-parallel size held in ``_ATTN_DP_SIZE``.

    Asserts that the stored value is not ``None``.
    """
    dp_size = _ATTN_DP_SIZE
    assert dp_size is not None, "dp attention not initialized!"
    return dp_size

def get_moe_tp_rank() -> int:
    """Return the MoE tensor-parallel rank held in ``_MOT_TP_RANK``.

    Asserts that the stored value is not ``None``.
    """
    moe_rank = _MOT_TP_RANK
    assert moe_rank is not None, "dp attention not initialized!"
    return moe_rank

def get_moe_tp_size() -> int:
    """Return the MoE tensor-parallel size held in ``_MOT_TP_SIZE``.

    Asserts that the stored value is not ``None``.
    """
    moe_size = _MOT_TP_SIZE
    assert moe_size is not None, "dp attention not initialized!"
    return moe_size

def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention tensor parallelism.

    This is exactly what ``get_tp_group()`` returns.
    """
    attn_tp_group = get_tp_group()
    return attn_tp_group

def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` across the MoE tensor-parallel group.

    Args:
        input_: Tensor contributed by this rank.
        dim: Dimension forwarded to the group's ``all_gather`` (default -1).

    Returns:
        The tensor returned by the group's ``all_gather``.
    """
    moe_group = get_moe_tp_group()
    return moe_group.all_gather(input_, dim)

def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across the MoE tensor-parallel group.

    Args:
        input_: Tensor contributed by this rank.
        dim: Dimension forwarded to the group's ``reduce_scatter`` (default -1).

    Returns:
        The tensor returned by the group's ``reduce_scatter``.
    """
    moe_group = get_moe_tp_group()
    return moe_group.reduce_scatter(input_, dim)

def dp_gather(hidden_states: torch.Tensor) -> torch.Tensor:
    """All-gather ``hidden_states`` along dim 0 across the MoE TP group.

    When the attention tensor-parallel size is not 1, the input is first
    passed through ``tensor_model_parallel_reduce_scatter`` along dim 0; with
    a size of exactly 1 that step is skipped.

    Args:
        hidden_states: Tensor contributed by this rank.

    Returns:
        The result of the MoE tensor-parallel all-gather along dim 0.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)

def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter ``hidden_states`` along dim 0 across the MoE TP group.

    If the MoE group's world size differs from the attention data-parallel
    size, the scattered result is additionally passed through
    ``tensor_model_parallel_all_gather`` along dim 0.

    Args:
        hidden_states: Tensor contributed by this rank.

    Returns:
        The reduce-scattered tensor, all-gathered afterwards when the sizes
        differ as described above.
    """
    # Evaluate the size comparison before issuing any collective.
    needs_tp_gather = get_moe_tp_group().world_size != get_attention_dp_size()
    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if needs_tp_gather:
        scattered = tensor_model_parallel_all_gather(scattered, dim=0)
    return scattered