# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)

# Module-level logger, created through vLLM's init_logger with this module's
# name; used below for debug messages about applied sharding.
logger = init_logger(__name__)


class XlaQKVParallelLinear(nn.Module):
    """Stand-in for a ``QKVParallelLinear`` that keeps q, k and v separate.

    The fused qkv weight (and bias, if any) of the wrapped layer is split
    into three parameters at construction time. When a mesh is supplied, each
    parameter is moved to the ``'xla'`` device and annotated with
    ``xs.mark_sharding`` so that its first (output) dimension is mapped to
    mesh axis ``'x'``. ``forward`` concatenates the three projections again,
    so the output has the same layout as the fused layer's output.

    Args:
        qkv_linear: The ``QKVParallelLinear`` layer to copy weights from. Its
            ``tp_size`` must be 1 (TP > 1 is only supported under SPMD).
        mesh: Optional XLA SPMD mesh. When ``None`` the split parameters stay
            on CPU and no sharding annotation is applied.
    """

    def __init__(self,
                 qkv_linear: nn.Module,
                 mesh: Optional["xs.Mesh"] = None):
        super().__init__()
        assert isinstance(qkv_linear, QKVParallelLinear)
        # Mirror the output conventions of the layer being replaced.
        self.skip_bias_add = qkv_linear.skip_bias_add
        self.return_bias = qkv_linear.return_bias
        assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."

        # Declared here for type checkers; the parameters themselves are
        # registered in _load_weights_from_qkv_linear.
        self.q_weight: Parameter
        self.k_weight: Parameter
        self.v_weight: Parameter
        self.q_bias: Optional[Parameter]
        self.k_bias: Optional[Parameter]
        self.v_bias: Optional[Parameter]
        self._load_weights_from_qkv_linear(qkv_linear)
        if mesh is not None:
            self._shard_weight(mesh)

    def _shard_weight(self, mesh: "xs.Mesh"):
        """Move q/k/v parameters to the XLA device and annotate sharding.

        Weights get partition spec ``('x', None)`` and biases ``('x', )``,
        i.e. the output dimension is the one mapped to mesh axis ``'x'``.
        """
        self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
        self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
        self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
        xs.mark_sharding(self.q_weight, mesh, ('x', None))
        xs.mark_sharding(self.k_weight, mesh, ('x', None))
        xs.mark_sharding(self.v_weight, mesh, ('x', None))
        if self.q_bias is not None:
            assert self.k_bias is not None and self.v_bias is not None, \
                "QKVParallelLinear should have q, k, and v biases together."
            self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.q_bias, mesh, ('x', ))
            self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.k_bias, mesh, ('x', ))
            self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.v_bias, mesh, ('x', ))

    def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
        """Split the fused qkv weight/bias and register the pieces.

        Registers ``q_weight``, ``k_weight``, ``v_weight`` and the matching
        ``*_bias`` parameters (registered as ``None`` when the source layer
        has no bias). All registered parameters have ``requires_grad=False``.
        """
        q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
        q_end = q_proj_size
        k_end = q_proj_size + k_proj_size
        # The weight of qkv linear is a concatenation of q, k, and v weights
        # along the output dimension.
        qkv_weight = qkv_linear.weight.data.cpu()
        q_weight = Parameter(qkv_weight[:q_end], requires_grad=False)
        k_weight = Parameter(qkv_weight[q_end:k_end], requires_grad=False)
        v_weight = Parameter(qkv_weight[k_end:], requires_grad=False)
        self.register_parameter("q_weight", q_weight)
        self.register_parameter("k_weight", k_weight)
        self.register_parameter("v_weight", v_weight)

        if qkv_linear.bias is not None:
            # Take the bias to CPU exactly like the weight so that all split
            # parameters start out on the same device.
            qkv_bias = qkv_linear.bias.data.cpu()
            q_bias = Parameter(qkv_bias[:q_end], requires_grad=False)
            k_bias = Parameter(qkv_bias[q_end:k_end], requires_grad=False)
            v_bias = Parameter(qkv_bias[k_end:], requires_grad=False)
            self.register_parameter("q_bias", q_bias)
            self.register_parameter("k_bias", k_bias)
            self.register_parameter("v_bias", v_bias)
        else:
            self.register_parameter("q_bias", None)
            self.register_parameter("k_bias", None)
            self.register_parameter("v_bias", None)

    def forward(self, input):
        """Project ``input`` with the q, k and v weights and concatenate.

        Returns:
            The concatenated qkv projection if ``return_bias`` is false.
            Otherwise a tuple ``(qkv_proj, output_bias)`` where
            ``output_bias`` is the concatenated q/k/v bias when
            ``skip_bias_add`` is set and the layer has a bias, else ``None``.
        """
        # Same forward functionality as QKVParallelLinear, but doing qkv proj
        # separately.
        if self.skip_bias_add:
            # The bias is not applied here; it is handed back to the caller.
            q_bias = k_bias = v_bias = None
        else:
            q_bias, k_bias, v_bias = self.q_bias, self.k_bias, self.v_bias
        q_proj = F.linear(input, self.q_weight, q_bias)
        k_proj = F.linear(input, self.k_weight, k_bias)
        v_proj = F.linear(input, self.v_weight, v_bias)
        # The q/k/v projections will be split outside of the QKVParallelLinear.
        # Because we are replacing XlaQKVParallelLinear with the
        # QKVParallelLinear, we need to concatenate q, k, and v projections to
        # match the output shape of the QKVParallelLinear implementation even if
        # it seems to be redundant.
        # The concat and the following split will be noop, and should be
        # optimized away by the compiler.
        qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
        if not self.return_bias:
            return qkv_proj

        # Build the returned bias from the stored parameters, not from the
        # local q_bias/k_bias/v_bias, which are None whenever skip_bias_add
        # is set.
        output_bias = None
        if self.skip_bias_add and self.q_bias is not None:
            output_bias = torch.cat([self.q_bias, self.k_bias, self.v_bias],
                                    dim=-1)
        return qkv_proj, output_bias


def partition_column_parallel_linear(layer: torch.nn.Module,
                                     mesh: xs.Mesh) -> torch.nn.Module:
    """Annotate a ColumnParallelLinear weight for SPMD sharding, in place.

    The weight is marked with partition spec ``('x', None)``, mapping its
    first dimension to mesh axis ``'x'``. Only ``layer.weight`` is annotated.

    Args:
        layer: The layer to annotate; must be a ``ColumnParallelLinear``.
        mesh: XLA SPMD mesh handed to ``xs.mark_sharding``.

    Returns:
        The same ``layer`` object that was passed in.
    """
    assert isinstance(layer, ColumnParallelLinear)
    weight_spec = ('x', None)
    xs.mark_sharding(layer.weight, mesh, weight_spec)
    logger.debug("Applied column-parallel sharding to %s", layer)
    return layer


def partition_row_parallel_linear(layer: torch.nn.Module,
                                  mesh: xs.Mesh) -> torch.nn.Module:
    """Annotate a RowParallelLinear weight for SPMD sharding, in place.

    The weight is marked with partition spec ``(None, 'x')``, mapping its
    second dimension to mesh axis ``'x'``. Only ``layer.weight`` is annotated.

    Args:
        layer: The layer to annotate; must be a ``RowParallelLinear``.
        mesh: XLA SPMD mesh handed to ``xs.mark_sharding``.

    Returns:
        The same ``layer`` object that was passed in.
    """
    assert isinstance(layer, RowParallelLinear)
    weight_spec = (None, 'x')
    xs.mark_sharding(layer.weight, mesh, weight_spec)
    logger.debug("Applied row-parallel sharding to %s", layer)
    return layer


def partition_qkv_parallel_linear(layer: torch.nn.Module,
                                  mesh: xs.Mesh) -> torch.nn.Module:
    """Build a sharded XlaQKVParallelLinear from a QKVParallelLinear.

    Unlike the column/row helpers, this returns a new module rather than
    the input layer, so the caller has to install it in place of ``layer``.

    Args:
        layer: The layer to convert; must be a ``QKVParallelLinear``.
        mesh: XLA SPMD mesh passed on to ``XlaQKVParallelLinear``.

    Returns:
        The newly constructed ``XlaQKVParallelLinear``.
    """
    assert isinstance(layer, QKVParallelLinear)
    replacement = XlaQKVParallelLinear(layer, mesh)
    logger.debug("Applied qkv parallel sharding to %s", layer)
    return replacement


# Maps a module's class __qualname__ to the function that shards it.
# Each function takes (module, mesh) and returns either the same module or a
# replacement for it.
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict((
    ("QKVParallelLinear", partition_qkv_parallel_linear),
    ("ColumnParallelLinear", partition_column_parallel_linear),
    ("RowParallelLinear", partition_row_parallel_linear),
))


def get_fqn(module):
    """Return the ``__qualname__`` of ``module``'s class.

    This is the class's qualified name only; it carries no module/package
    path.
    """
    module_cls = module.__class__
    return module_cls.__qualname__


def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
    """
    Walk a PyTorch model depth-first and shard every submodule whose class
    name has an entry in MODULE_TYPE_TO_WRAPPING_FUNC.

    When a wrapping function returns a different object, that object is
    installed on the parent in place of the original submodule. The model is
    modified in place.

    Args:
        model: torch.nn.Module to process
        mesh: An XLA SPMD mesh object used for sharding
    """

    def _visit(current, attr_name=None, owner=None):
        type_name = get_fqn(current)
        if type_name in MODULE_TYPE_TO_WRAPPING_FUNC:
            wrap = MODULE_TYPE_TO_WRAPPING_FUNC[type_name]
            replacement = wrap(current, mesh)

            assert owner is not None and attr_name is not None, (
                "Top Level module is not expected to be wrapped.")
            if replacement is not current:
                # A new object came back, so the owner must point at it
                # instead of the original module.
                logger.debug("replace %s with %s", current, replacement)
                setattr(owner, attr_name, replacement)

            # Continue the walk from whatever now sits in the tree.
            current = replacement

        # Snapshot the children first: visiting one may setattr on `current`.
        children = list(current.named_children())
        for child_name, child in children:
            _visit(child, child_name, current)

    _visit(model)