# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers.
"""PyTorch Cohere model."""
25
from collections.abc import Iterable
26
from itertools import islice
27
from typing import Optional, Union
28
29
30

import torch
from torch import nn
31
from transformers import Cohere2Config, CohereConfig
32

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
44
45
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import (
47
48
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
50
from vllm.model_executor.utils import set_weight_attrs
51
from vllm.platforms import current_platform
52
from vllm.sequence import IntermediateTensors
53

54
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
55
56
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

60

61
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Apply a bias-free LayerNorm over the last dimension.

    Mean, variance and the scale multiplication are computed in float32;
    the result is cast back to the dtype of ``hidden_states``.

    Args:
        hidden_states: Input tensor, normalized over its last dimension.
        weight: Scale tensor multiplied (with broadcasting) into the
            normalized input. No bias term is applied.
        variance_epsilon: Constant added to the variance before ``rsqrt``.

    Returns:
        The normalized, scaled tensor in the input dtype.
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    # Center once and reuse it for both the (biased) variance and the output.
    centered = hidden_states - mean
    variance = centered.pow(2).mean(-1, keepdim=True)
    hidden_states = centered * torch.rsqrt(variance + variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


73
74
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learnable scale, used by Cohere models.

    ``forward`` follows a ``(hidden_states, residual)`` calling convention
    but does not use the residual: it is returned unchanged.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        """Create the scale parameter, initialized to ones.

        Args:
            param_shape: Shape passed to ``torch.ones`` for the scale, e.g.
                ``hidden_size`` or ``(num_heads, head_dim)`` for per-head
                QK norms. Every caller in this file passes it explicitly.
            eps: Constant added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        # Loaded via row_parallel_weight_loader; for the per-head QK-norm
        # weights this presumably narrows the checkpoint tensor to this TP
        # rank's heads -- TODO confirm against weight_utils.
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states``; pass ``residuals`` through as-is."""
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        return hidden_states, residuals
86
87
88
89
90
91
92


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated SiLU feed-forward block for Cohere models.

    ``gate_proj`` and ``up_proj`` are fused into a single column-parallel
    ``gate_up_proj`` producing two ``intermediate_size`` halves, and
    ``down_proj`` is row-parallel. No biases are used.
    """

    def __init__(
        self,
        config: Union[CohereConfig, Cohere2Config],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Fused gate + up projection: output is [gate | up].
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # Presumably computes silu(gate) * up over the two halves --
        # confirm in vllm.model_executor.layers.activation.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class CohereAttention(nn.Module):
    """Self-attention for Cohere (v1) and Cohere2 (v2) models.

    v1 is detected with ``isinstance(config, CohereConfig)`` and applies
    rotary embeddings on every layer. For v2, layers whose
    ``config.layer_types`` entry is ``"sliding_attention"`` get a sliding
    window and rotary embeddings; the remaining layers get neither.
    When ``config.use_qk_norm`` is set, per-head LayerNorm is applied to
    queries and keys before the rotary embedding.
    """

    def __init__(
        self,
        config: Union[CohereConfig, Cohere2Config],
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Per-rank widths of the q / k / v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Prefer a truthy `model_max_length`; otherwise fall back to
        # `max_position_embeddings` (8192 if that attribute is absent).
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Model v2 has interleaved sliding windows, v1 does not
        self.v1 = isinstance(config, CohereConfig)

        self.sliding_window = None
        if not self.v1:
            # The layer index is parsed from the module prefix.
            layer_idx = extract_layer_index(prefix)
            if config.layer_types[layer_idx] == "sliding_attention":
                self.sliding_window = config.sliding_window

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")
        if self.use_qk_norm:
            # Per-head norm weights, sized for this rank's heads.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        """Apply per-head LayerNorm to flattened ``q`` and ``k``.

        The last dimension is unflattened into ``(-1, head_dim)``, so each
        head is normalized over ``head_dim`` with its own weights. The
        result is then flattened back to the original layout.
        """
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # v1 applies RoPE on every layer; v2 only on sliding-window layers
        # (its other layers get no rotary embedding).
        if self.v1 or self.sliding_window:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class CohereDecoderLayer(nn.Module):
    """Cohere decoder layer with parallel attention and MLP branches.

    A single input LayerNorm feeds both branches, and their outputs are
    added to the un-normalized input:
    ``out = x + attn(norm(x)) + mlp(norm(x))``.
    """

    def __init__(self,
                 config: Union[CohereConfig, Cohere2Config],
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")

        self.mlp = CohereMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Layer input.
            residual: Accepted for the model's layer-loop interface; its
                incoming value is ignored (overwritten below).

        Returns:
            ``(hidden_states, residual)`` where ``residual`` is this layer's
            input ``hidden_states``.
        """
        # Self Attention
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        return hidden_states, residual


276
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere transformer backbone: token embeddings, decoder layers, final norm.

    ``make_layers`` returns ``(start_layer, end_layer, layers)``, and
    ``forward`` runs only ``layers[start_layer:end_layer]``. Non-last
    pipeline ranks return ``IntermediateTensors`` instead of normalized
    hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.quant_config = quant_config

        self.config = config
        # Extra vocabulary slots reserved for LoRA adapters (0 without LoRA).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CohereDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=config.hidden_size,
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. Later ranks read
        ``hidden_states`` and ``residual`` from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on non-last ranks; otherwise the hidden
            states after the final norm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this (possibly partial) model.

        Checkpoint entries fall into three groups:

        * KV-cache quantization scales reported by ``quant_config``.
        * Separate ``q/k/v_proj`` and ``gate/up_proj`` tensors. These are
          loaded as shards of the fused ``qkv_proj`` / ``gate_up_proj``.
        * Everything else, loaded one-to-one by (possibly remapped) name.

        Weights of layers reported missing by ``is_pp_missing_parameter``
        are skipped. So are extra ``.bias`` tensors with no matching
        parameter (GPTQ).

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the tensor itself if it is 0-dim, else its first entry.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # NOTE(review): the two `continue`s below advance to the next
                # mapping entry, not the next weight. With `name` already
                # rewritten, such weights fall through to the `else` branch,
                # where the same bias / PP-missing skip checks are applied.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

396

397
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Cohere / Cohere2 causal LM with tied input/output embeddings.

    Logits are computed by passing the input embedding module
    (``model.embed_tokens``) to the ``LogitsProcessor``; checkpoint
    ``lm_head`` weights are skipped on load. ``config.logit_scale`` is
    given to the processor as ``scale``.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # Currently all existing Command R models have `tie_word_embeddings`
        # enabled, and compute_logits always reuses `embed_tokens`. Raise
        # explicitly (an `assert` is stripped under `python -O`) so an untied
        # checkpoint fails loudly instead of producing wrong logits.
        if not config.tie_word_embeddings:
            raise ValueError(
                "Cohere models require tied word embeddings "
                "(config.tie_word_embeddings must be True).")
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone; see ``CohereModel.forward``."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the tied embeddings."""
        embed_tokens = self.model.embed_tokens
        # Without a `weight` attribute the embedding is wrapped (the LoRA
        # case); use the wrapped layer exposed as `base_layer`.
        if not hasattr(embed_tokens, 'weight'):
            embed_tokens = embed_tokens.base_layer
        logits = self.logits_processor(embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # `lm_head` is tied to the embeddings and `rotary_emb.inv_freq` is
        # not a parameter here, so both are skipped.
        loader = AutoWeightsLoader(
            self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"])
        return loader.load_weights(weights)