"tools/vscode:/vscode.git/clone" did not exist on "02f0c7b220422792f5e53de2a7d51d2d3ff2df28"
commandr.py 18.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
24
from typing import Iterable, Optional, Set, Tuple, Union
25
26
27
28
29

import torch
from torch import nn
from transformers import CohereConfig

30
from vllm.attention import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
36
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
41
42
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import (
44
45
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.platforms import current_platform
49
from vllm.sequence import IntermediateTensors
50

51
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
52
from .utils import (extract_layer_index, is_pp_missing_parameter,
53
54
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

56

57
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free layer normalization over the last dimension.

    The computation is done in float32 regardless of the input dtype:
    the input is centred by its mean over the last dimension, divided by
    ``sqrt(variance + variance_epsilon)`` where the variance is the mean of
    squared deviations, multiplied element-wise by ``weight`` (also upcast
    to float32), and finally cast back to the input's original dtype.
    No bias term is added.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    centered = x - x.mean(-1, keepdim=True)
    var = centered.pow(2).mean(-1, keepdim=True)
    normalized = centered * torch.rsqrt(var + variance_epsilon)
    scaled = weight.to(torch.float32) * normalized
    return scaled.to(orig_dtype)


69
70
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learnable scale of shape ``param_shape``.

    The scale parameter is tagged with ``row_parallel_weight_loader`` as its
    ``weight_loader`` attribute, which checkpoint loading picks up.
    ``forward`` returns a ``(normalized, residuals)`` pair; ``residuals`` is
    passed through unchanged so callers can always unpack two values.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(param_shape))
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        normalized = layer_norm_func(hidden_states, self.weight,
                                     self.variance_epsilon)
        return normalized, residuals
82
83
84
85
86
87
88


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block (LLaMA-style MLP), tensor-parallel.

    ``gate_up_proj`` fuses the gate and up projections into one
    column-parallel matmul producing two ``intermediate_size`` outputs.
    The ``SiluAndMul`` activation is applied to that fused output, and the
    row-parallel ``down_proj`` maps the result back to ``hidden_size``.
    Neither projection has a bias.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        hidden, inter = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inter
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inter, inter],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Parallel linear layers return (output, bias); bias is unused here.
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out


class CohereAttention(nn.Module):
    """Multi-head / grouped-query self-attention for Cohere models.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as TP ranks and replicated
    (at least one per rank) otherwise. When ``config.use_qk_norm`` is set,
    queries and keys get a per-head LayerNorm before RoPE.

    Rotary embeddings:
      * "v1" configs (no ``interleaved_sliding_window`` attribute) apply
        RoPE on every layer;
      * otherwise RoPE is applied only on layers that use the sliding
        window; the remaining layers skip it.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size

        # Query heads: split evenly across TP ranks.
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # KV heads: replicated when fewer than TP ranks, partitioned
        # otherwise; either way the counts must divide evenly.
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # A truthy ``model_max_length`` wins; otherwise fall back to
        # ``max_position_embeddings`` (8192 if that is absent too).
        self.max_position_embeddings = (
            getattr(config, "model_max_length", None)
            or getattr(config, "max_position_embeddings", 8192))
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Model v2 has interleaved sliding windows, v1 does not
        window_size = getattr(config, "interleaved_sliding_window", None)
        self.v1 = window_size is None

        # With a sliding_window_pattern of P, every P-th layer (counting
        # from 1) does NOT use the sliding window; all other layers do.
        layer_idx = extract_layer_index(prefix)
        pattern = getattr(config, "sliding_window_pattern", False)
        uses_window = pattern and (layer_idx + 1) % pattern != 0
        self.sliding_window = window_size if uses_window else None

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")

        if self.use_qk_norm:
            eps = config.layer_norm_eps
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=eps)

    def _apply_qk_norm(self, q, k):
        """Normalize each head of the flattened ``q`` / ``k`` projections.

        The last dimension is unflattened into ``(-1, head_dim)`` so the
        norm runs over ``head_dim`` for every head, then flattened back.
        """
        q_heads = q.view(*q.shape[:-1], -1, self.head_dim)
        k_heads = k.view(*k.shape[:-1], -1, self.head_dim)
        q_normed, _ = self.q_norm(q_heads)
        k_normed, _ = self.k_norm(k_heads)
        return (q_normed.view(*q_normed.shape[:-2], -1),
                k_normed.view(*k_normed.shape[:-2], -1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # v1: RoPE on every layer. v2: RoPE only on sliding-window layers.
        if self.v1 or self.sliding_window:
            q, k = self.rotary_emb(positions, q, k)
        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected


class CohereDecoderLayer(nn.Module):
    """Cohere decoder layer with parallel attention and MLP branches.

    One LayerNorm feeds both self-attention and the MLP. Their outputs are
    added to the layer's (un-normalized) input. ``forward`` returns
    ``(new_hidden_states, layer_input)``.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        self.mlp = CohereMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The incoming ``residual`` argument is not used: this layer's own
        # input serves as its residual.
        layer_input = hidden_states
        normed, residual = self.input_layernorm(hidden_states, layer_input)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)
        mlp_out = self.mlp(normed)
        # Both branches read the same normalized input; sum everything.
        hidden_states = residual + attn_out + mlp_out
        return hidden_states, residual


277
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware: only the first rank embeds tokens; other ranks
    resume from ``intermediate_tensors``. Every rank except the last returns
    an ``IntermediateTensors`` with ``hidden_states`` and ``residual``. The
    last rank returns the final-normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Extra vocab reserved for LoRA, scaled by max_loras (at least 1).
        lora_vocab = 0
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)

        def build_layer(prefix: str) -> CohereDecoderLayer:
            return CohereDecoderLayer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=config.hidden_size,
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later PP ranks resume from the previous rank's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states,
                                            residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        normed, _ = self.norm(hidden_states, residual)
        return normed


341
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Cohere causal LM: ``CohereModel`` plus a logits head.

    The output projection is tied to the input embedding. ``__init__``
    asserts ``config.tie_word_embeddings``, and ``compute_logits`` reuses
    ``model.embed_tokens``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    embedding_modules = {"embed_tokens": "input_embeddings"}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits with the tied embedding matrix."""
        embed = self.model.embed_tokens
        # A LoRA-wrapped embedding has no ``weight`` of its own; use the
        # wrapped ``base_layer`` in that case.
        head = embed if hasattr(embed, 'weight') else embed.base_layer
        return self.logits_processor(head, hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names (keys of ``named_parameters()``) that received data.

        Handling of checkpoint names:
          * ``rotary_emb.inv_freq`` entries are ignored.
          * KV-cache scales recognized by ``quant_config.get_cache_scale``
            are loaded as scalars (element 0 if the tensor is not 0-d).
          * ``q/k/v_proj`` and ``gate/up_proj`` shards are routed into the
            fused ``qkv_proj`` / ``gate_up_proj`` parameters with their
            shard id.
          * ``lm_head.weight`` is skipped (the head is tied to the
            embedding).
          * Biases with no matching parameter (extra GPTQ biases) and
            parameters owned by another pipeline-parallel rank are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Skip loading rotary embeddings since vLLM has its own
            if "rotary_emb.inv_freq" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Pick the FIRST matching shard only. The fused names contain
            # other shard names ("v_proj" in "qkv_proj", "up_proj" in
            # "gate_up_proj"). So a skip must leave the outer loop directly
            # rather than resuming the mapping search on a rewritten name.
            stacked = next(
                ((param_name, shard_name, shard_id)
                 for param_name, shard_name, shard_id in stacked_params_mapping
                 if shard_name in name),
                None,
            )

            if stacked is not None:
                param_name, shard_name, shard_id = stacked
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params