commandr.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers
"""PyTorch Cohere model."""
25

26
from collections.abc import Iterable
27
from itertools import islice
28
29
30

import torch
from torch import nn
31
from transformers import Cohere2Config, CohereConfig
32

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
37
from vllm.model_executor.layers.activation import SiluAndMul
38
39
40
41
42
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.quantization import QuantizationConfig
45
from vllm.model_executor.layers.rotary_embedding import get_rope
46
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
47
from vllm.model_executor.model_loader.weight_utils import (
48
49
50
51
    default_weight_loader,
    maybe_remap_kv_scale_name,
    row_parallel_weight_loader,
)
52
from vllm.model_executor.utils import set_weight_attrs
53
from vllm.platforms import current_platform
54
from vllm.sequence import IntermediateTensors
55

56
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
57
58
59
60
61
62
63
64
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
65

66

67
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free LayerNorm over the last dimension, computed in float32.

    Args:
        hidden_states: Tensor normalized along its last dimension.
        weight: Scale multiplied into the normalized values (broadcast
            against ``hidden_states``).
        variance_epsilon: Added to the variance before ``rsqrt`` for
            numerical stability.

    Returns:
        ``weight * (x - mean) * rsqrt(var + eps)`` cast back to the dtype of
        ``hidden_states``. No bias term is added.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    # Center once and reuse it: both the variance and the output need it.
    centered = x - x.mean(-1, keepdim=True)
    inv_std = torch.rsqrt(centered.pow(2).mean(-1, keepdim=True) + variance_epsilon)
    normalized = centered * inv_std
    return (weight.to(torch.float32) * normalized).to(orig_dtype)


78
class LayerNorm(nn.Module):
    """Bias-free LayerNorm module built on ``layer_norm_func``.

    ``forward`` takes an optional residual and hands it back untouched, so the
    module fits the ``(hidden_states, residual)`` calling convention used by
    the decoder code in this file.

    Args:
        param_shape: Shape of the learnable scale (initialized to ones).
        eps: Epsilon added to the variance inside the normalization.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        scale = nn.Parameter(torch.ones(param_shape))
        # NOTE(review): row_parallel_weight_loader is presumably what lets
        # multi-dimensional scales (e.g. per-head q/k norms) be split across
        # tensor-parallel ranks; confirm against its implementation.
        set_weight_attrs(scale, {"weight_loader": row_parallel_weight_loader})
        self.weight = scale
        self.variance_epsilon = eps

    def forward(self, hidden_states, residuals=None):
        normed = layer_norm_func(hidden_states, self.weight, self.variance_epsilon)
        return normed, residuals


# Adapted from the Llama MLP in transformers (LlamaMLP), renamed for Cohere.
class CohereMLP(nn.Module):
    """Gated feed-forward block with tensor-parallel projections.

    ``gate_up_proj`` fuses the gate and up projections into one
    column-parallel matmul with two output slices of ``intermediate_size``.
    ``SiluAndMul`` is applied to that fused output, and the row-parallel
    ``down_proj`` maps the result back to ``hidden_size``. Neither projection
    has a bias.
    """

    def __init__(
        self,
        config: CohereConfig | Cohere2Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        inter = config.intermediate_size

        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inter
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inter, inter],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class CohereAttention(nn.Module):
    """Self-attention for Cohere (Command R) models, sharded across TP ranks.

    Both config generations are handled:

    * ``CohereConfig`` (v1): rotary embeddings are applied on every layer and
      no sliding window is used.
    * ``Cohere2Config`` (v2): the layer's entry in ``config.layer_types``
      decides the behavior. ``"sliding_attention"`` layers get
      ``config.sliding_window`` and rotary embeddings. All other layers run
      with no window and without rotary embeddings.

    When ``config.use_qk_norm`` is set, queries and keys are LayerNorm-ed per
    head before the rotary embedding.
    """

    def __init__(
        self,
        config: CohereConfig | Cohere2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size

        # Query heads: split evenly over tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # KV heads are replicated when there are fewer of them than ranks, and
        # partitioned otherwise. Either way the counts must divide evenly.
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # A truthy ``model_max_length`` wins. Otherwise fall back to
        # ``max_position_embeddings``, or 8192 if that is absent.
        model_max_length = getattr(config, "model_max_length", None)
        self.max_position_embeddings = model_max_length or getattr(
            config, "max_position_embeddings", 8192
        )
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Only v2 configs carry per-layer attention types. The layer index is
        # taken from this module's prefix.
        self.v1 = isinstance(config, CohereConfig)
        self.sliding_window = None
        if not self.v1:
            layer_type = config.layer_types[extract_layer_index(prefix)]
            if layer_type == "sliding_attention":
                self.sliding_window = config.sliding_window

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
        )

        if self.use_qk_norm:
            norm_eps = config.layer_norm_eps
            self.q_norm = LayerNorm(
                param_shape=(self.num_heads, self.head_dim), eps=norm_eps
            )
            self.k_norm = LayerNorm(
                param_shape=(self.num_kv_heads, self.head_dim), eps=norm_eps
            )

    def _apply_qk_norm(self, q, k):
        """Normalize each query/key head independently.

        The flat projections are viewed with a trailing ``(heads, head_dim)``
        pair so the per-head norm weights line up. They are then flattened
        back to the incoming layout.
        """
        q_heads = q.view(*q.shape[:-1], -1, self.head_dim)
        k_heads = k.view(*k.shape[:-1], -1, self.head_dim)
        q_normed = self.q_norm(q_heads)[0]
        k_normed = self.k_norm(k_heads)[0]
        return (
            q_normed.view(*q_normed.shape[:-2], -1),
            k_normed.view(*k_normed.shape[:-2], -1),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)

        # v1 always uses RoPE. v2 uses it only on sliding-window layers.
        uses_rope = self.v1 or bool(self.sliding_window)
        if uses_rope:
            q, k = self.rotary_emb(positions, q, k)

        out, _ = self.o_proj(self.attn(q, k, v))
        return out


class CohereDecoderLayer(nn.Module):
    """Cohere decoder block with parallel attention and MLP branches.

    One LayerNorm output feeds both the attention and the MLP. The two branch
    outputs are added to the layer's un-normalized input, and there is no
    second norm.
    """

    def __init__(
        self,
        config: CohereConfig | Cohere2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = CohereMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
        self.input_layernorm = LayerNorm(
            param_shape=config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # The incoming ``residual`` is intentionally ignored: the skip
        # connection is always this layer's own input, which is also what gets
        # returned as the new residual.
        skip = hidden_states
        normed, skip = self.input_layernorm(hidden_states, skip)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)
        mlp_out = self.mlp(normed)
        return skip + attn_out + mlp_out, skip


283
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere transformer stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only layers
    ``[start_layer, end_layer)``. Non-first ranks read ``hidden_states`` and
    ``residual`` from ``IntermediateTensors``. Non-last ranks return those two
    tensors instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )

        # make_layers passes the per-layer prefix as the ``prefix`` keyword.
        def build_layer(prefix: str):
            return CohereDecoderLayer(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = LayerNorm(param_shape=config.hidden_size, eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        elif inputs_embeds is None:
            hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            hidden_states = inputs_embeds
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters by
        shard id. KV-cache quantization scales reported by the quant config are
        loaded as scalars. Everything else uses the parameter's own
        ``weight_loader``, or ``default_weight_loader`` if it has none.

        Returns:
            The set of parameter names that were loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            scale_name = (
                self.quant_config.get_cache_scale(name)
                if self.quant_config is not None
                else None
            )
            if scale_name:
                # KV-cache scale: take a 0-d tensor as-is, else its first entry.
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                scale_value = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, scale_value)
                loaded_params.add(scale_name)
                continue

            # NOTE(review): ``name`` is rewritten in place and the rewritten
            # value carries over to later mapping entries when a ``continue``
            # fires. This is kept deliberately to match the original loader.
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # GPTQ checkpoints may ship biases this model does not have.
                skip_bias = name.endswith(".bias") and name not in params_dict
                if skip_bias or is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # GPTQ checkpoints may ship biases this model does not have.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # FP8 kv-scale names may need remapping. None means "drop it".
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None or is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params

402

403
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Causal-LM wrapper around ``CohereModel``.

    There is no separate LM head. Logits are computed against the token
    embedding matrix (tied embeddings), scaled by ``config.logit_scale``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        # All existing Command R models enable ``tie_word_embeddings``.
        # compute_logits depends on it because it uses the input embedding as
        # the output head.
        assert config.tie_word_embeddings

        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(
            config.vocab_size, scale=config.logit_scale
        )
        self.model = CohereModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        embedding = self.model.embed_tokens
        # Without a ``weight`` attribute the embedding is treated as a LoRA
        # wrapper, and the real embedding sits in ``base_layer``.
        head = embedding if hasattr(embedding, "weight") else embedding.base_layer
        return self.logits_processor(head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # lm_head is tied to embed_tokens. inv_freq is recomputed, not loaded.
        skipped = ["lm_head", "rotary_emb.inv_freq"]
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)