commandr.py 19 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
24
from typing import Iterable, Optional, Set, Tuple, Union
25
26
27
28
29

import torch
from torch import nn
from transformers import CohereConfig

30
from vllm.attention import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
36
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
41
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
42
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import (
45
46
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
47
from vllm.model_executor.sampling_metadata import SamplingMetadata
48
from vllm.model_executor.utils import set_weight_attrs
49
from vllm.platforms import current_platform
50
from vllm.sequence import IntermediateTensors
51

52
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
53
from .utils import (extract_layer_index, is_pp_missing_parameter,
54
55
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
56

57

58
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Apply a bias-free LayerNorm over the last dimension.

    The statistics and the scaling are computed in float32. The result is
    cast back to the dtype of ``hidden_states``.

    Args:
        hidden_states: Input tensor; normalized along its last dimension.
        weight: Learnable scale, multiplied element-wise after normalization.
        variance_epsilon: Added to the variance before the reciprocal sqrt.

    Returns:
        Tensor with the same dtype as ``hidden_states``.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    # Center once and reuse the centered tensor for both variance and output.
    centered = x - x.mean(-1, keepdim=True)
    var = centered.pow(2).mean(-1, keepdim=True)
    normed = centered * torch.rsqrt(var + variance_epsilon)
    scaled = weight.to(torch.float32) * normed
    return scaled.to(orig_dtype)


70
71
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learnable scale of shape ``param_shape``.

    Every caller in this file passes ``param_shape`` explicitly. The scale is
    loaded through ``row_parallel_weight_loader``.
    NOTE(review): presumably this is so that the 2-D per-head q/k norm
    weights are split across tensor-parallel ranks -- confirm.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.variance_epsilon = eps
        scale = nn.Parameter(torch.ones(param_shape))
        set_weight_attrs(scale, {"weight_loader": row_parallel_weight_loader})
        self.weight = scale

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states``.

        Returns:
            A ``(normalized, residuals)`` pair; ``residuals`` is returned
            unchanged.
        """
        normalized = layer_norm_func(hidden_states, self.weight,
                                     self.variance_epsilon)
        return normalized, residuals
83
84
85
86
87
88
89


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block.

    The input is projected by a fused gate/up ``MergedColumnParallelLinear``,
    passed through ``SiluAndMul`` and projected back to ``hidden_size`` by a
    ``RowParallelLinear``. Neither projection has a bias.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        hidden = self.hidden_size = config.hidden_size
        intermediate = self.intermediate_size = config.intermediate_size

        # Gate and up projections share one GEMM; two equal output shards.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [intermediate, intermediate],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class CohereAttention(nn.Module):
    """Self-attention for Command-R (v1) and Command-R v2 configs.

    Q/K/V come from one fused ``QKVParallelLinear``; the output goes through
    a ``RowParallelLinear``. When ``config.use_qk_norm`` is set, queries and
    keys are normalized per head with ``LayerNorm``.

    Rotary embeddings are applied differently by version:
      * v1 configs (no ``interleaved_sliding_window``): applied on every layer.
      * v2 configs: applied only on layers that get a per-layer sliding window.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size

        # Head bookkeeping. Query heads are divided across TP ranks. KV heads
        # are partitioned when there are at least as many as ranks, and
        # replicated otherwise. Either way, the larger of the two counts must
        # be an exact multiple of the smaller one.
        self.total_num_heads = config.num_attention_heads
        self.total_num_kv_heads = config.num_key_value_heads
        assert (max(self.total_num_kv_heads, tp_size) %
                min(self.total_num_kv_heads, tp_size) == 0)
        self.num_heads = self.total_num_heads // tp_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Prefer `model_max_length` when present and truthy.
        self.max_position_embeddings = (
            getattr(config, "model_max_length", None)
            or getattr(config, "max_position_embeddings", 8192))
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Only v2 configs define `interleaved_sliding_window`.
        interleaved_sliding_window = getattr(config,
                                             "interleaved_sliding_window",
                                             None)
        self.v1 = interleaved_sliding_window is None

        # With a pattern P, every P-th layer (1-based) gets no per-layer
        # sliding window; all other layers use `interleaved_sliding_window`.
        layer_idx = extract_layer_index(prefix)
        window_pattern = getattr(config, "sliding_window_pattern", False)
        layer_has_sliding_window = (
            window_pattern and (layer_idx + 1) % window_pattern != 0)
        if layer_has_sliding_window:
            self.sliding_window = interleaved_sliding_window
        else:
            self.sliding_window = None

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")

        if self.use_qk_norm:
            norm_eps = config.layer_norm_eps
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=norm_eps)

    def _apply_qk_norm(self, q, k):
        """Normalize each query/key head independently.

        The last dimension is split into ``(-1, head_dim)``, normalized, and
        flattened back.
        """
        q_heads = q.view(*q.shape[:-1], -1, self.head_dim)
        k_heads = k.view(*k.shape[:-1], -1, self.head_dim)
        q_heads, _ = self.q_norm(q_heads)
        k_heads, _ = self.k_norm(k_heads)
        return (q_heads.view(*q_heads.shape[:-2], -1),
                k_heads.view(*k_heads.shape[:-2], -1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # v1: RoPE on every layer. v2: RoPE only where a per-layer sliding
        # window is configured.
        apply_rope = self.v1 or self.sliding_window
        if apply_rope:
            q, k = self.rotary_emb(positions, q, k)
        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected


class CohereDecoderLayer(nn.Module):
    """One Cohere decoder block with parallel attention and MLP branches.

    A single ``input_layernorm`` feeds both the attention and the MLP. Their
    outputs are summed with the un-normalized block input.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        self.mlp = CohereMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        The incoming ``residual`` is not read; it is replaced by the block
        input.

        Returns:
            ``(block_input + attention_out + mlp_out, block_input)``.
        """
        block_input = hidden_states
        normed, block_input = self.input_layernorm(hidden_states, block_input)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
        )
        mlp_out = self.mlp(normed)
        # Parallel residual: both branches are added to the block input.
        combined = block_input + attn_out + mlp_out
        return combined, block_input


278
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere decoder stack.

    Holds the token embedding, this pipeline stage's slice of decoder layers
    (built via ``make_layers``), and the final ``LayerNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config

        # Vocabulary size including the LoRA extra-vocab slots, if any.
        # NOTE(review): `self.vocab_size` is not used within this class; the
        # embedding below is sized with the base `config.vocab_size`.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)

        # `make_layers` passes the per-layer prefix by keyword, so the
        # parameter must be named `prefix`.
        def build_layer(prefix: str):
            return CohereDecoderLayer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=config.hidden_size,
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage.

        Returns:
            ``IntermediateTensors`` on every stage but the last; on the last
            stage, the final-normed hidden states.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later stages resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        final_states, _ = self.norm(hidden_states, residual)
        return final_states


342
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Command-R causal language model.

    Wraps :class:`CohereModel` and computes logits against the input
    embedding matrix. Supported checkpoints tie the LM head to the token
    embeddings, and ``lm_head.weight`` is skipped when loading.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings, (
            "CohereForCausalLM requires tie_word_embeddings=True: logits are "
            "computed from the input embedding matrix.")
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocab logits with the tied embedding.

        If ``embed_tokens`` has no ``weight`` attribute, its ``base_layer``
        is used instead.
        NOTE(review): that attribute check appears to be how a LoRA-wrapped
        embedding is detected -- confirm.
        """
        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
        if is_not_lora:
            logits = self.logits_processor(self.model.embed_tokens,
                                           hidden_states, sampling_metadata)
        else:
            logits = self.logits_processor(self.model.embed_tokens.base_layer,
                                           hidden_states, sampling_metadata)

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. The following are
        skipped: ``lm_head.weight`` (tied), extra GPTQ biases with no
        matching parameter, and parameters owned by other pipeline stages.

        Returns:
            The names (keys of ``named_parameters()``) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Remap FP8 kv-scale names BEFORE the stacked-shard matching.
            # Otherwise a name like "...self_attn.k_proj.k_scale" matches the
            # "k_proj" shard below and becomes "...self_attn.qkv_proj.k_scale",
            # which is not a parameter (KeyError). Done exactly once: a second
            # remap would turn "attn.k_scale" into "attn.attn.k_scale".
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # NOTE: the two `continue`s below advance this shard loop,
                # not the weight loop. The renamed weight is then dropped by
                # the matching check in the `else` branch.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params