commandr.py 17.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers.
"""PyTorch Cohere model."""
23
from typing import Iterable, List, Optional, Set, Tuple, Union
24
25
26
27
28
29
30

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import SiluAndMul
34
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
35
36
37
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.rotary_embedding import get_rope
40
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
41
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import (
44
45
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.sequence import IntermediateTensors
49

50
51
52
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
53

54

55
56
57
58
59
60
61
62
63
64
65
66
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free layer norm over the last dimension, computed in float32.

    The input is centered by its mean along dim -1, scaled by
    ``1 / sqrt(variance + variance_epsilon)``, multiplied elementwise by
    ``weight`` and finally cast back to the input's original dtype.
    """
    orig_dtype = hidden_states.dtype
    x32 = hidden_states.to(torch.float32)
    centered = x32 - x32.mean(-1, keepdim=True)
    var = centered.pow(2).mean(-1, keepdim=True)
    normed = centered * torch.rsqrt(var + variance_epsilon)
    scaled = weight.to(torch.float32) * normed
    return scaled.to(orig_dtype)


67
68
class LayerNorm(nn.Module):
    """Bias-free LayerNorm module wrapping ``layer_norm_func``.

    ``forward`` returns a ``(normalized, residuals)`` pair, passing
    ``residuals`` through untouched, so call sites in this file can unpack
    it like ``hidden_states, residual = norm(hidden_states, residual)``.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.variance_epsilon = eps
        weight = nn.Parameter(torch.ones(param_shape))
        # Load the checkpoint weight with vLLM's row-parallel loader
        # (presumably sharding along dim 0 across TP ranks -- confirm).
        set_weight_attrs(weight,
                         {"weight_loader": row_parallel_weight_loader})
        self.weight = weight

    def forward(self, hidden_states, residuals=None):
        normed = layer_norm_func(hidden_states, self.weight,
                                 self.variance_epsilon)
        return normed, residuals


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, activation, down.

    ``gate_up_proj`` produces both intermediate projections in one matmul;
    ``SiluAndMul`` combines them (presumably ``silu(gate) * up`` -- see the
    vLLM activation layer) before ``down_proj`` maps back to hidden size.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden = config.hidden_size
        self.intermediate_size = inter = config.intermediate_size
        # Two output shards of equal size: [gate, up].
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inter, inter],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


class CohereAttention(nn.Module):
    """Self-attention with separate query and key/value head counts.

    Query heads are divided evenly across tensor-parallel ranks; KV heads
    are either partitioned or replicated depending on how their count
    compares to the TP world size. When ``config.use_qk_norm`` is set, a
    per-head LayerNorm is applied to queries and keys before the rotary
    embedding.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size

        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads

        total_kv = config.num_key_value_heads
        self.total_num_kv_heads = total_kv
        # Fewer KV heads than ranks: replicate them so each rank has one.
        # Otherwise: partition them. Either way the larger of the two
        # counts must be a multiple of the smaller.
        if total_kv < tp_size:
            assert tp_size % total_kv == 0
        else:
            assert total_kv % tp_size == 0
        self.num_kv_heads = max(1, total_kv // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Prefer a truthy `model_max_length`; otherwise fall back to
        # `max_position_embeddings` (8192 when the config lacks it).
        max_len = getattr(config, "model_max_length", None)
        if not max_len:
            max_len = getattr(config, "max_position_embeddings", 8192)
        self.max_position_embeddings = max_len

        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Rotary embedding over the full head dimension, non-NeoX style.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        if self.use_qk_norm:
            norm_eps = config.layer_norm_eps
            # One (head_dim,) weight row per local head.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=norm_eps)

    def _apply_qk_norm(self, q, k):
        """Normalize each head of the flattened ``q`` and ``k`` separately.

        The last dimension is viewed as ``(num_heads, head_dim)``, normalized
        over ``head_dim``, and flattened back to its original width.
        """

        def _norm_per_head(flat, norm):
            per_head = flat.view(*flat.shape[:-1], -1, self.head_dim)
            normed, _ = norm(per_head)
            return normed.view(*normed.shape[:-2], -1)

        return _norm_per_head(q, self.q_norm), _norm_per_head(k, self.k_norm)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class CohereDecoderLayer(nn.Module):
    """One Cohere block with attention and MLP run in parallel.

    A single input LayerNorm feeds both the attention and the MLP branch;
    their outputs are added to the un-normalized layer input.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config)
        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The incoming `residual` is not used: this layer's residual is
        # always its own input, which is also returned as the second value.
        layer_input = hidden_states
        normed, residual = self.input_layernorm(hidden_states, layer_input)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        return residual + attn_out + mlp_out, residual


class CohereModel(nn.Module):
    """Cohere decoder stack: token embedding, decoder layers, final norm.

    With pipeline parallelism this rank runs only layers
    ``[start_layer, end_layer)``. Non-final ranks return their activations
    as ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        # Reserve `lora_extra_vocab_size` ids for each of `max_loras`
        # adapter slots (at least one slot).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)

        def _make_decoder_layer(prefix):
            # `prefix` is presumably supplied by make_layers (possibly by
            # keyword, so keep the name); it is not used here.
            return CohereDecoderLayer(config, cache_config, quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            _make_decoder_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = LayerNorm(param_shape=config.hidden_size,
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Resume from activations handed over by the previous PP rank.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = self.embed_tokens(input_ids)
            residual = None

        # kv_caches appears to hold only this rank's layers, so it is
        # indexed relative to start_layer.
        local_layer_ids = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layer_ids):
            hidden_states, residual = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
                residual,
            )

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


315
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Cohere (Command R) causal LM: ``CohereModel`` plus a tied LM head.

    ``compute_logits`` passes the input embedding layer to the
    ``LogitsProcessor`` (word embeddings are tied, see the assertion in
    ``__init__``), so checkpoint ``lm_head.weight`` tensors are skipped
    during loading.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        embed = self.model.embed_tokens
        # When LoRA wraps the embedding, the wrapper has no `weight`; the
        # original embedding layer is reachable through `base_layer`.
        lm_head = embed if hasattr(embed, "weight") else embed.base_layer
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Copy checkpoint tensors into this model's parameters.

        Checkpoint ``q/k/v_proj`` and ``gate/up_proj`` tensors are written
        shard by shard into the fused ``qkv_proj`` / ``gate_up_proj``
        parameters. ``lm_head.weight`` and GPTQ biases without a matching
        parameter are skipped. FP8 kv-scale names are remapped. Names that
        ``is_pp_missing_parameter`` reports as absent on this pipeline rank
        are ignored.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (keys of ``named_parameters()``) that
            received a tensor, so callers can check load coverage.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        # Previously built and then dropped; return it so callers can
        # verify which parameters were populated.
        return loaded_params