transformers.py 10 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py
"""Wrapper around `transformers` models"""
import logging
import re
from typing import Iterable, Literal, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

logger = logging.getLogger(__name__)


def maybe_prefix(prefix: str, name: str) -> str:
    """Join ``prefix`` and ``name`` with a dot, unless ``prefix`` is falsy.

    Args:
        prefix: Leading qualifier (e.g. a parent module path). An empty
            (falsy) prefix means "no qualifier".
        name: The name to qualify.

    Returns:
        ``f"{prefix}.{name}"`` when ``prefix`` is truthy, else ``name`` as-is.
    """
    if prefix:
        return f"{prefix}.{name}"
    return name


def sglang_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # sglang kwargs
    forward_batch: ForwardBatch,
    # Transformers kwargs
    scaling: float = None,
    attention_instances: list[RadixAttention] = None,
    **kwargs,
):
    """Attention function registered with transformers under the name "sglang".

    Routes the projected q/k/v of a transformers attention module to the
    SGLang ``RadixAttention`` instance for that module's layer.

    Args:
        module: The calling attention module; only ``module.layer_idx`` is
            read, to pick the entry from ``attention_instances``.
        query, key, value: Projected tensors. Presumably laid out as
            (batch, num_heads, seq_len, head_dim) with batch == 1, as implied
            by the ``transpose(1, 2)`` + ``reshape(seq_len, -1)`` below —
            TODO confirm.
        attention_mask: Accepted for interface compatibility; not used here.
        forward_batch: Passed straight through to ``RadixAttention.forward``.
        scaling: When not None, overwrites the selected instance's
            ``scaling`` attribute (cast to float) before the call.
        attention_instances: Per-layer ``RadixAttention`` objects indexed by
            ``layer_idx``.
        **kwargs: Ignored.

    Returns:
        A 2-tuple of the ``RadixAttention.forward`` result and ``None``
        (presumably the attention-weights slot transformers expects).
    """
    radix_attn: RadixAttention = attention_instances[module.layer_idx]
    if scaling is not None:
        radix_attn.scaling = float(scaling)

    # Dim -2 before the transpose; presumably the token count.
    num_tokens = query.shape[-2]

    def _flatten(tensor):
        # Swap dims 1 and 2, then collapse everything after the token axis.
        return tensor.transpose(1, 2).reshape(num_tokens, -1)

    q_flat = _flatten(query)
    k_flat = _flatten(key)
    v_flat = _flatten(value)
    attn_output = radix_attn.forward(
        q_flat, k_flat, v_flat, forward_batch=forward_batch
    )
    return attn_output, None


# Make the function selectable via `attn_implementation="sglang"`.
ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward


class HFColumnParallelLinear(ColumnParallelLinear):
    """ColumnParallelLinear whose forward returns only the parent's first output.

    This makes the layer usable where an ``nn.Linear``-style single-tensor
    return is expected.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        parent_result = super().forward(input)
        return parent_result[0]


class HFRowParallelLinear(RowParallelLinear):
    """RowParallelLinear whose forward returns only the parent's first output.

    This makes the layer usable where an ``nn.Linear``-style single-tensor
    return is expected.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        parent_result = super().forward(input)
        return parent_result[0]


def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "rowwise"],
    quant_config: QuantizationConfig,
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of SGLang's tensor parallel linear classes.

    The returned layer is an instance of a subclass of the selected SGLang
    linear class. That subclass's ``forward`` returns only the first element
    of the parent's output, dropping ``output_bias``. This lets it stand in
    for ``nn.Linear`` inside a transformers model.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its
            ``in_features``, ``out_features`` and whether it has a bias are
            read; its weights are not copied.
        style (str): Tensor parallel style of the new linear. "colwise"
            selects ColumnParallelLinear and "rowwise" selects
            RowParallelLinear; any other string falls back to
            ReplicatedLinear.
        quant_config (QuantConfig): Quantization config for the new linear.

    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
        The new linear.

    Raises:
        ValueError: If ``style`` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Unrecognized style strings are not an error: they get a replicated
    # (non-sharded) layer.
    sglang_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    class HFCompatibleLinear(sglang_linear_cls):
        """
        Wrapper class that removes `output_bias` from returned output.
        """

        @property
        def parent_cls(self) -> type:
            return sglang_linear_cls

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]

    return HFCompatibleLinear(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
    )


class TransformersForCausalLM(nn.Module):
    """Causal-LM wrapper that runs a `transformers` model with SGLang layers.

    The base model is built with ``AutoModel.from_config`` using the
    ``"sglang"`` attention implementation. Several pieces are then swapped
    or added:

    - linear layers named in ``base_model_tp_plan`` become SGLang
      tensor-parallel linears;
    - the input embedding becomes a ``VocabParallelEmbedding``;
    - a ``ParallelLMHead`` and a ``LogitsProcessor`` are added on top.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the wrapped model and swap in SGLang layers.

        Args:
            config: Hugging Face model config.
            quant_config: Optional SGLang quantization config, forwarded to
                the replaced linears, the attention instances and the LM head.
            prefix: Name prefix used for the LM head's ``prefix``.
        """
        super().__init__()
        logger.info("Using Transformers backend.")

        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size

        # model is loaded under set_default_torch_dtype(model_config.dtype)
        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            torch_dtype=torch.get_default_dtype(),
            attn_implementation="sglang",
            trust_remote_code=True,
        )

        # Attention modifications (assumes 1 attention op per hidden layer)
        tp_size = get_tensor_model_parallel_world_size()
        # Keep the TP world size on the instance as well; the original
        # tensor_parallel() read `self.tp_size` without it ever being set.
        self.tp_size = tp_size

        # MLP modifications
        self.tensor_parallel(tp_size)

        # Some configs define `head_dim` but leave it as None; fall back to
        # hidden_size // num_attention_heads in that case as well.
        head_dim = getattr(config, "head_dim", None) or (
            config.hidden_size // config.num_attention_heads
        )
        # Configs without grouped-query attention may omit (or null out)
        # num_key_value_heads; treat that as one KV head per query head.
        num_kv_heads = (
            getattr(config, "num_key_value_heads", None)
            or config.num_attention_heads
        )
        self.attention_instances = [
            RadixAttention(
                num_heads=divide(config.num_attention_heads, tp_size),
                head_dim=head_dim,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in sglang_flash_attention_forward
                scaling=head_dim**-0.5,
                num_kv_heads=divide(num_kv_heads, tp_size),
                layer_id=i,
                quant_config=self.quant_config,
                prefix=f"{i}.attn",
            )
            for i in range(config.num_hidden_layers)
        ]

        # Model modifications
        self.replace_vocab_embed_class(self.model)

        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share the (already replaced) input-embedding weight with the head.
            self.lm_head.weight = self.model.get_input_embeddings().weight

        self.logits_processor = LogitsProcessor(config)

    def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module):
        """Log a module swap at debug level."""
        logger.debug("%s: %s -> %s", name, old_module, new_module)

    def tensor_parallel(self, tp_size: int):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Each child module whose qualified name matches (via ``re.match``) a
        pattern in the config's ``base_model_tp_plan`` is replaced, if it is
        an ``nn.Linear``, by the result of ``replace_linear_class``. The first
        matching pattern wins.

        Args:
            tp_size: Tensor-parallel world size.

        Raises:
            ValueError: If ``tp_size > 1`` and the config has no (or an
                empty) ``base_model_tp_plan``.
        """
        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}

        # Use the argument, not an instance attribute, so this method works
        # regardless of what has been assigned on `self` beforehand.
        if not tp_plan and tp_size > 1:
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel yet!"
            )

        def _tensor_parallel(module: nn.Module, prefix: str = ""):
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                for pattern, style in tp_plan.items():
                    if re.match(pattern, qual_name) and isinstance(
                        child_module, nn.Linear
                    ):
                        new_module = replace_linear_class(
                            child_module, style, self.quant_config
                        )
                        setattr(module, child_name, new_module)
                        self.log_replacement(qual_name, child_module, new_module)
                        # First match wins. Without this break, a later
                        # matching pattern would rebuild the layer again from
                        # the original nn.Linear, and the for-else below would
                        # still recurse into the freshly replaced module.
                        break
                else:
                    # Nothing replaced this child: descend into it.
                    _tensor_parallel(child_module, prefix=qual_name)

        _tensor_parallel(self.model)

    def replace_vocab_embed_class(self, module: nn.Module):
        """Swap ``module``'s input embedding for a ``VocabParallelEmbedding``.

        Args:
            module: Model whose input embedding is replaced; ``__init__``
                passes ``self.model``.
        """
        # Use native set input embeddings
        new_module = VocabParallelEmbedding(
            self.vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.vocab_size,
            quant_config=None,
        )
        self.log_replacement(
            "input embedding", module.get_input_embeddings(), new_module
        )
        module.set_input_embeddings(new_module)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Run the wrapped model and compute logits with the LM head.

        Args:
            input_ids: Token ids; a leading batch dimension of size 1 is added
                before calling the HF model.
            positions: Position ids; also given a leading batch dimension.
            forward_batch: SGLang batch metadata, passed to the model as a
                kwarg. It is presumably propagated by transformers to
                ``sglang_flash_attention_forward``.
            input_embeds: Accepted for interface compatibility; currently
                not used.
            get_embedding: Must be False; embedding mode is not supported.

        Returns:
            The result of ``self.logits_processor``.
        """
        assert get_embedding is False, "embedding is not supported yet"
        aux_hidden_states = None
        model_outputs = self.model(
            input_ids[None, ...],
            use_cache=False,
            position_ids=positions[None, ...],
            forward_batch=forward_batch,
            attention_instances=self.attention_instances,
            return_dict=False,
        )
        # we remove batch dimension for now
        hidden_states = model_outputs[0][0, ...]

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into matching parameters.

        A name not found as-is is retried with ``base_model_prefix`` prepended
        (presumably for checkpoints saved from the bare base model). Names
        that still do not match any parameter are skipped and logged at debug
        level.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name not in params_dict:
                name = f"{self.model.base_model_prefix}.{name}"
            param = params_dict.get(name)
            if param is None:
                logger.debug("Skipping weight %s: no matching parameter", name)
                continue
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


# Model classes exported from this module.
EntryClass = [TransformersForCausalLM]