"docs/vscode:/vscode.git/clone" did not exist on "5ed704ec8c4e68f1bc846ab4e3c9e355585d62da"
commandr.py 19.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
24
from typing import Iterable, List, Optional, Set, Tuple, Union
25
26
27
28
29
30
31

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import SiluAndMul
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
39
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
42
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
43
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
45
from vllm.model_executor.model_loader.weight_utils import (
46
47
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.model_executor.utils import set_weight_attrs
50
from vllm.platforms import current_platform
51
from vllm.sequence import IntermediateTensors
52

53
from .interfaces import SupportsLoRA, SupportsPP
54
from .utils import (extract_layer_index, is_pp_missing_parameter,
55
56
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
57

58

59
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Apply a bias-free layer normalization over the last dimension.

    The mean, the variance (mean of squared deviations) and the weight
    multiplication are all evaluated in float32; the result is cast back to
    the dtype ``hidden_states`` arrived in.

    Args:
        hidden_states: Tensor normalized along its last dimension.
        weight: Scale applied to the normalized values (broadcast).
        variance_epsilon: Constant added to the variance before the
            reciprocal square root, for numerical stability.

    Returns:
        ``weight * normalized(hidden_states)`` in the input dtype.
    """
    original_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    centered = x - x.mean(-1, keepdim=True)
    inv_std = torch.rsqrt(
        centered.pow(2).mean(-1, keepdim=True) + variance_epsilon)
    scaled = weight.to(torch.float32) * (centered * inv_std)
    return scaled.to(original_dtype)


71
72
class LayerNorm(nn.Module):
    """Bias-free layer norm with a learnable ``weight`` initialized to ones.

    ``forward`` returns a ``(normalized, residuals)`` pair, passing the
    ``residuals`` argument through untouched, so the module fits call sites
    that expect the (hidden_states, residual) convention.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        weight = nn.Parameter(torch.ones(param_shape))
        # Custom loader used by load_weights(); presumably it shards
        # multi-dimensional weights (e.g. per-head QK norms) across TP
        # ranks -- confirm in model_loader.weight_utils.
        set_weight_attrs(weight,
                         {"weight_loader": row_parallel_weight_loader})
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, hidden_states, residuals=None):
        normalized = layer_norm_func(hidden_states, self.weight,
                                     self.variance_epsilon)
        return normalized, residuals
84
85
86
87
88
89
90


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block.

    A fused gate/up column-parallel projection feeds ``SiluAndMul``, whose
    output is projected back to ``hidden_size`` by a row-parallel layer.
    Neither projection has a bias.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        # Gate and up projections share one matmul: two output shards of
        # `intermediate` width each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [intermediate, intermediate],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate,
            hidden,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class CohereAttention(nn.Module):
    """Tensor-parallel self-attention for Command R models.

    Features driven by the HF config:
      * ``use_qk_norm``: per-head bias-free LayerNorm on queries and keys.
      * ``interleaved_sliding_window`` / ``sliding_window_pattern``: when the
        config carries an interleaved window ("v2" models), every
        ``sliding_window_pattern``-th layer (1-based) uses global attention
        and the others use the sliding window. Rotary embeddings are applied
        for v1 models and for sliding-window layers only.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        # Pass the module prefix so quantization configs that resolve
        # schemes / ignore-lists by layer name can match these layers.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Model v2 has interleaved sliding windows, v1 does not
        interleaved_sliding_window = getattr(config,
                                             "interleaved_sliding_window",
                                             None)
        self.v1 = interleaved_sliding_window is None

        # Only parse the layer index out of the prefix when a pattern is
        # configured; extract_layer_index needs a prefix containing exactly
        # one integer, so the default prefix="" would otherwise fail.
        sliding_window_pattern = getattr(config, "sliding_window_pattern",
                                         False)
        if sliding_window_pattern:
            layer_idx = extract_layer_index(prefix)
            layer_has_sliding_window = (
                (layer_idx + 1) % sliding_window_pattern != 0)
        else:
            layer_has_sliding_window = False

        self.sliding_window = (interleaved_sliding_window
                               if layer_has_sliding_window else None)

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")
        if self.use_qk_norm:
            # One weight row per (local) head; normalization is over
            # head_dim.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        """Normalize q and k per head, preserving their input shapes."""
        # Split the last dimension into (heads, head_dim) so LayerNorm
        # normalizes each head independently.
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        # Merge the head axes back into one last dimension.
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and return the output of ``o_proj``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # v2 global-attention layers (no sliding window) skip rotary.
        if self.v1 or self.sliding_window:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class CohereDecoderLayer(nn.Module):
    """One Cohere decoder block with parallel attention and MLP.

    A single LayerNorm output feeds both the attention and the MLP, and the
    block output is ``input + attention + mlp``.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The incoming `residual` argument is not used: the block's own
        # input becomes the residual, and it is returned alongside the sum.
        residual = hidden_states
        normed, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        combined = residual + attn_out + mlp_out
        return combined, residual


278
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere decoder stack.

    It holds the token embedding, the decoder layers and a final LayerNorm.
    Only layers in ``[start_layer, end_layer)`` run on this pipeline-parallel
    rank. Ranks other than the last hand off ``IntermediateTensors`` instead
    of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        extra_lora_vocab = 0
        if lora_config:
            extra_lora_vocab = (lora_config.lora_extra_vocab_size *
                                (lora_config.max_loras or 1))
        self.vocab_size = hf_config.vocab_size + extra_lora_vocab
        self.org_vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hf_config.hidden_size)

        # make_layers calls this with a `prefix=` keyword for each layer.
        def build_layer(prefix: str) -> CohereDecoderLayer:
            return CohereDecoderLayer(hf_config,
                                      cache_config,
                                      quant_config,
                                      prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=hf_config.hidden_size,
                              eps=hf_config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        for layer_idx in range(self.start_layer, self.end_layer):
            # kv_caches only holds this stage's layers, hence the offset.
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
                residual,
            )

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


347
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Command R causal language model.

    Wraps :class:`CohereModel` and produces logits through the input
    embedding matrix, because the embeddings are tied. Logits are scaled by
    ``config.logit_scale``.
    """

    # Checkpoint projections that vLLM fuses into a single layer.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        extra_lora_vocab = (lora_config.lora_extra_vocab_size
                            if lora_config else 0)
        self.unpadded_vocab_size = config.vocab_size + extra_lora_vocab
        self.quant_config = vllm_config.quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        embedding = self.model.embed_tokens
        # A LoRA-wrapped embedding has no `weight` of its own; the tied
        # matrix then lives on its `base_layer`.
        lm_head = (embedding
                   if hasattr(embedding, 'weight') else embedding.base_layer)
        return self.logits_processor(lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Returns the set of parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, tensor in weights:
            scale_name = (self.quant_config.get_cache_scale(name)
                          if self.quant_config is not None else None)
            if scale_name:
                # Loading kv cache quantization scales; a non-scalar tensor
                # contributes only its first element.
                param = params_dict[scale_name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                if tensor.dim() != 0:
                    tensor = tensor[0]
                loader(param, tensor)
                loaded_params.add(scale_name)
                continue

            # NOTE(review): the `continue`s below resume the mapping loop
            # with the already-rewritten name rather than skipping the
            # tensor; the for/else branch then performs the final skip.
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            loaded_params.add(name)
        return loaded_params