"vscode:/vscode.git/clone" did not exist on "b9ff4f2a8dffc84b2ce226e7e98c33756caf098f"
commandr.py 17.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers.
"""PyTorch Cohere model."""
25

26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Optional, Union
29
30
31

import torch
from torch import nn
32
from transformers import Cohere2Config, CohereConfig
33

34
from vllm.attention import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
38
from vllm.model_executor.layers.activation import SiluAndMul
39
40
41
42
43
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
47
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
48
from vllm.model_executor.model_loader.weight_utils import (
49
50
51
52
    default_weight_loader,
    maybe_remap_kv_scale_name,
    row_parallel_weight_loader,
)
53
from vllm.model_executor.utils import set_weight_attrs
54
from vllm.platforms import current_platform
55
from vllm.sequence import IntermediateTensors
56

57
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
58
59
60
61
62
63
64
65
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
66

67

68
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free layer normalization over the last dimension.

    The computation runs in float32 and the result is cast back to the
    input's dtype.

    Args:
        hidden_states: Input tensor; statistics are taken over its last dim.
        weight: Multiplicative scale applied after normalization; it must
            broadcast against ``hidden_states``.
        variance_epsilon: Added to the variance before ``rsqrt``.

    Returns:
        ``weight * (x - mean) / sqrt(var + eps)`` in the input dtype.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    # Center once and reuse it for both the variance and the normalization.
    centered = x - x.mean(-1, keepdim=True)
    variance = centered.pow(2).mean(-1, keepdim=True)
    normed = centered * torch.rsqrt(variance + variance_epsilon)
    return (weight.to(torch.float32) * normed).to(orig_dtype)


79
class LayerNorm(nn.Module):
    """Layer norm with a learnable scale ``weight`` (initialized to ones) and no bias.

    The weight is tagged with ``row_parallel_weight_loader`` as its
    ``weight_loader``. See ``vllm.model_executor.model_loader.weight_utils``
    for how that loader shards tensors under tensor parallelism.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        scale = nn.Parameter(torch.ones(param_shape))
        set_weight_attrs(scale, {"weight_loader": row_parallel_weight_loader})
        self.weight = scale
        self.variance_epsilon = eps

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states`` and return ``(normed, residuals)``.

        ``residuals`` is not used in the computation; it is returned as-is so
        callers can unpack a ``(hidden, residual)`` pair.
        """
        normed = layer_norm_func(hidden_states, self.weight, self.variance_epsilon)
        return normed, residuals
91
92
93
94
95
96


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Cohere feed-forward block.

    The gate and up projections are fused into a single column-parallel
    linear layer with output sizes ``[intermediate_size] * 2``. ``SiluAndMul``
    is applied to that fused output, and a row-parallel ``down_proj`` maps it
    back to ``hidden_size``. None of the projections has a bias.
    """

    def __init__(
        self,
        config: Union[CohereConfig, Cohere2Config],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``; linear biases are discarded."""
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class CohereAttention(nn.Module):
    """Self-attention for Cohere (``CohereConfig``) and Cohere2 models.

    Query heads are divided evenly across tensor-parallel ranks. KV heads are
    divided across ranks when there are at least as many KV heads as ranks,
    and replicated otherwise; each rank keeps at least one KV head.

    Rotary position embeddings are applied on every layer of a v1
    (``CohereConfig``) model. For other configs, a layer gets a sliding window
    only if its ``config.layer_types`` entry is ``"sliding_attention"``, and
    rotary embeddings are applied only on those layers. Per-head LayerNorm on
    queries and keys is enabled by ``config.use_qk_norm``.
    """

    def __init__(
        self,
        config: Union[CohereConfig, Cohere2Config],
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_world = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size

        # Per-rank head layout.
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_world
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_world:
            # Fewer KV heads than ranks: they are replicated, so the rank
            # count must be a multiple of the KV head count.
            assert tp_world % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: they are partitioned evenly.
            assert self.total_num_kv_heads % tp_world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # A truthy ``model_max_length`` takes precedence over
        # ``max_position_embeddings`` (which itself defaults to 8192).
        max_positions = getattr(config, "model_max_length", None)
        if not max_positions:
            max_positions = getattr(config, "max_position_embeddings", 8192)
        self.max_position_embeddings = max_positions
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Model v2 interleaves sliding-window layers with other layers; v1
        # has no per-layer sliding window.
        self.v1 = isinstance(config, CohereConfig)
        self.sliding_window = None
        if not self.v1:
            layer_type = config.layer_types[extract_layer_index(prefix)]
            if layer_type == "sliding_attention":
                self.sliding_window = config.sliding_window

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
        )
        if self.use_qk_norm:
            norm_eps = config.layer_norm_eps
            self.q_norm = LayerNorm(
                param_shape=(self.num_heads, self.head_dim), eps=norm_eps
            )
            self.k_norm = LayerNorm(
                param_shape=(self.num_kv_heads, self.head_dim), eps=norm_eps
            )

    def _apply_qk_norm(self, q, k):
        """Apply per-head LayerNorm to flattened query and key tensors.

        The last dim of each tensor is split into ``(-1, head_dim)``. The
        tensor is normalized in that shape and then flattened back.
        """
        hd = self.head_dim
        q_heads, _ = self.q_norm(q.view(*q.shape[:-1], -1, hd))
        k_heads, _ = self.k_norm(k.view(*k.shape[:-1], -1, hd))
        return (
            q_heads.view(*q_heads.shape[:-2], -1),
            k_heads.view(*k_heads.shape[:-2], -1),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute self-attention over ``hidden_states`` at ``positions``.

        The input is projected to q/k/v. If enabled, q and k are normalized
        per head and then rotated. Attention is computed, and the result is
        projected back to ``hidden_size``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # Rotary embeddings apply to every v1 layer, but only to v2 layers
        # that received a sliding window.
        apply_rope = self.v1 or self.sliding_window
        if apply_rope:
            q, k = self.rotary_emb(positions, q, k)
        output, _ = self.o_proj(self.attn(q, k, v))
        return output


class CohereDecoderLayer(nn.Module):
    """Cohere decoder layer with parallel attention and MLP branches.

    A single ``input_layernorm`` output feeds both ``self_attn`` and ``mlp``.
    The two branch outputs are added to the layer's un-normalized input.
    """

    def __init__(
        self,
        config: Union[CohereConfig, Cohere2Config],
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = CohereMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
        eps = config.layer_norm_eps
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one parallel-residual decoder step.

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Layer input.
            residual: Ignored. The layer uses its own input as the residual.

        Returns:
            ``(output, residual)``, where ``residual`` is the value returned
            by ``input_layernorm`` for this layer's input.
        """
        skip = hidden_states
        normed, skip = self.input_layernorm(hidden_states, skip)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)
        mlp_out = self.mlp(normed)
        # Both branches read the same normalized input; sum them with the skip.
        return skip + attn_out + mlp_out, skip


284
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere transformer backbone: token embeddings, decoder layers, final norm.

    Under pipeline parallelism, only the first rank embeds tokens and only the
    last rank applies the final norm. Other ranks return ``hidden_states`` and
    ``residual`` wrapped in ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.quant_config = quant_config
        self.config = hf_config

        # Extra vocab slots for LoRA: lora_extra_vocab_size * max_loras,
        # with max_loras treated as 1 when it is unset or zero.
        lora_vocab = 0
        if lora_config:
            lora_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1
            )
        self.vocab_size = hf_config.vocab_size + lora_vocab
        self.org_vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size, hf_config.hidden_size
        )

        def build_layer(prefix: str) -> CohereDecoderLayer:
            return CohereDecoderLayer(
                hf_config, cache_config, quant_config, prefix=prefix
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = LayerNorm(
            param_shape=hf_config.hidden_size, eps=hf_config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of decoder layers.

        The first rank starts from ``inputs_embeds`` (or embeds ``input_ids``).
        Later ranks read ``hidden_states`` / ``residual`` from
        ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the final-normed
        hidden states.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each incoming tensor is handled in one of three ways:

        - as a KV-cache quantization scale, when the quant config reports one
          for its name;
        - as a shard of a fused ``qkv_proj`` / ``gate_up_proj`` parameter;
        - otherwise, as a direct parameter.

        ``.bias`` tensors with no matching parameter (e.g. extra GPTQ biases)
        are skipped. So are parameters that ``is_pp_missing_parameter``
        reports as absent on this rank.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            scale_name = (
                self.quant_config.get_cache_scale(name)
                if self.quant_config is not None
                else None
            )
            if scale_name:
                # KV-cache quantization scale: load a scalar, taking the first
                # element when the checkpoint tensor is not 0-dim.
                param = params[scale_name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor if tensor.dim() == 0 else tensor[0])
                loaded.add(scale_name)
                continue

            for packed_name, shard_name, shard_id in stacked_mapping:
                if shard_name not in name:
                    continue
                # NOTE: ``name`` is rewritten in place. Later mapping entries
                # and the ``else`` branch below see the rewritten name.
                name = name.replace(shard_name, packed_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params)
                if name is None or is_pp_missing_parameter(name, self):
                    continue
                param = params[name]
                getattr(param, "weight_loader", default_weight_loader)(param, tensor)
            loaded.add(name)
        return loaded

408

409
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Causal-LM wrapper for Cohere / Cohere2 with tied input/output embeddings.

    Logits are computed from ``model.embed_tokens``, and ``config.logit_scale``
    is passed to the logits processor as ``scale``. There is no separate
    ``lm_head`` module, and ``lm_head`` tensors in a checkpoint are skipped
    when loading.
    """

    # Fused vLLM module -> checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # All existing Command R models have `tie_word_embeddings` enabled,
        # and this class relies on it: logits always come from
        # `embed_tokens`, and `lm_head` weights are skipped on load. Use an
        # explicit check rather than `assert`, which `python -O` strips, so
        # an untied checkpoint fails loudly instead of giving wrong logits.
        if not config.tie_word_embeddings:
            raise ValueError(
                "CohereForCausalLM requires `tie_word_embeddings=True`; "
                "untied output embeddings are not supported."
            )

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale
        )
        self.model = CohereModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone without autograd and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the tied embedding."""
        embed = self.model.embed_tokens
        # If the embedding module has no `weight` of its own (presumably a
        # LoRA wrapper), use the layer it wraps via `base_layer`.
        head = embed if hasattr(embed, "weight") else embed.base_layer
        return self.logits_processor(head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        ``lm_head`` (tied to the embeddings) and ``rotary_emb.inv_freq``
        entries are skipped.
        """
        loader = AutoWeightsLoader(
            self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"]
        )
        return loader.load_weights(weights)