# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
22
from typing import Iterable, List, Optional, Set, Tuple, Union
23
24
25
26
27
28
29

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import SiluAndMul
34
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
41
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.platforms import current_platform
49
from vllm.sequence import IntermediateTensors
50

51
from .interfaces import SupportsLoRA, SupportsPP
52
from .utils import (extract_layer_index, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

56

57
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free LayerNorm over the last dimension, computed in float32.

    Subtracts the mean and divides by the standard deviation along dim -1,
    multiplies by ``weight`` (no bias term is added), and casts the result
    back to the input dtype.

    Args:
        hidden_states: Input tensor; statistics are taken over dim -1.
        weight: Scale tensor, broadcast against the normalized input.
        variance_epsilon: Constant added to the variance before the
            reciprocal square root, for numerical stability.

    Returns:
        The normalized, scaled tensor in ``hidden_states``' original dtype.
    """
    input_dtype = hidden_states.dtype
    # Do the reductions in float32 regardless of the input dtype.
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    # Center once and reuse it for both the variance and the normalization.
    centered = hidden_states - mean
    variance = centered.pow(2).mean(-1, keepdim=True)
    hidden_states = centered * torch.rsqrt(variance + variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


69
70
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learnable scale initialized to ones.

    ``forward`` accepts and returns a ``residuals`` value so it can be called
    as ``norm(hidden_states, residual)`` like the other norms in this file;
    the residual is returned unchanged.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        """
        Args:
            param_shape: Shape of the scale parameter (int or tuple), passed
                straight to ``torch.ones``.
            eps: Value added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        # NOTE(review): the row-parallel loader presumably slices dim 0 of
        # the checkpoint tensor per tensor-parallel rank (needed for the
        # per-head q/k norms) — confirm in weight_utils.
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states``; pass ``residuals`` through as-is."""
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        return hidden_states, residuals
82
83
84
85
86
87
88


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiluAndMul, down.

    Args:
        config: Model config; ``hidden_size`` and ``intermediate_size`` are
            read from it.
        quant_config: Optional quantization config forwarded to both linear
            layers.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up projections share one column-parallel matmul that
        # emits two `intermediate_size` outputs.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Presumably computes silu(gate) * up on the fused output's two
        # halves — see vllm.model_executor.layers.activation.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the down-projected result."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class CohereAttention(nn.Module):
    """Tensor-parallel (grouped-query) self-attention for Cohere models.

    Q/K/V come from one fused projection. They are optionally normalized per
    head (``use_qk_norm``) and rotated with RoPE: on every layer for v1
    models, and only on sliding-window layers for v2 models. The attention
    output is then projected back to ``hidden_size``.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Use `model_max_length` when the config sets it to a truthy value,
        # otherwise `max_position_embeddings`, otherwise 8192.
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )

        # Model v2 has interleaved sliding windows, v1 does not
        interleaved_sliding_window = getattr(config,
                                             "interleaved_sliding_window",
                                             None)
        self.v1 = interleaved_sliding_window is None

        # With sliding_window_pattern == P, every P-th layer (counting from
        # 1) gets no per-layer sliding window; all other layers use
        # `interleaved_sliding_window`. The layer index is presumably parsed
        # from `prefix` by extract_layer_index.
        layer_idx = extract_layer_index(prefix)
        layer_has_sliding_window = (
            getattr(config, "sliding_window_pattern", False)
            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)

        self.sliding_window = (interleaved_sliding_window
                               if layer_has_sliding_window else None)

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              per_layer_sliding_window=self.sliding_window,
                              prefix=f"{prefix}.attn")
        if self.use_qk_norm:
            # Per-head norms sized for this rank's local head counts.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        """Normalize each head of ``q`` and ``k`` independently.

        The last dim is split into ``(-1, head_dim)`` so the norm runs over
        ``head_dim``; the head axis is then flattened back.
        """
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention on ``hidden_states`` and return o_proj output."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        # v1 models rotate on every layer; v2 models only on layers that
        # have a sliding window.
        if self.v1 or self.sliding_window:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class CohereDecoderLayer(nn.Module):
    """Parallel Cohere decoder block.

    One input LayerNorm feeds both the attention and the MLP branch. The two
    branch outputs are added to the layer's un-normalized input.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")

        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size, ),
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Returns:
            ``(output, residual)``, where ``residual`` is this layer's input.
        """
        # The incoming `residual` argument is not used. In this parallel
        # block the residual is simply the layer input. The parameter exists
        # to match the (hidden_states, residual) convention of CohereModel.
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        return hidden_states, residual


276
@support_torch_compile
class CohereModel(nn.Module):
    """Cohere decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism, only the first rank embeds tokens and only
    the last rank applies the final norm. Intermediate ranks exchange
    ``hidden_states`` / ``residual`` via IntermediateTensors.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        # NOTE(review): vocab_size includes the LoRA extra vocab, but
        # embed_tokens below is sized with config.vocab_size only — confirm
        # this is intended.
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # forward() runs layers[start_layer:end_layer] on this rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CohereDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=(config.hidden_size, ),
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Returns:
            The final normalized hidden states on the last PP rank, or an
            IntermediateTensors with ``hidden_states`` and ``residual`` on
            earlier ranks.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches holds only this rank's layers, hence the offset.
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


345
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Cohere causal LM: CohereModel plus tied-embedding logits and sampler.

    The output projection reuses ``model.embed_tokens`` (the config is
    asserted to have ``tie_word_embeddings``). ``config.logit_scale`` is
    passed to the logits processor as ``scale``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to CohereModel.forward; see it for the return contract."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits using the tied embedding."""
        # When embed_tokens has no `weight` attribute (presumably the LoRA
        # wrapper), project with its `base_layer` instead.
        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
        if is_not_lora:
            logits = self.logits_processor(self.model.embed_tokens,
                                           hidden_states, sampling_metadata)
        else:
            logits = self.logits_processor(self.model.embed_tokens.base_layer,
                                           hidden_states, sampling_metadata)

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are renamed onto the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters and loaded as the
        matching shard. ``lm_head.weight`` is skipped because logits use
        the tied ``embed_tokens``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # NOTE: the two `continue`s below advance the inner mapping
                # loop, not the outer weight loop. The loop then reaches the
                # `else` branch, whose matching checks skip the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params