# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers
"""PyTorch Cohere model."""
23
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name,
    row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
54

55

56
57
58
59
60
61
62
63
64
65
66
67
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Normalize ``hidden_states`` over its last dimension and scale it.

    The arithmetic is carried out in float32 and the result is cast back to
    the dtype the input arrived in. No bias term is applied.

    Args:
        hidden_states: Tensor to normalize; statistics are taken over dim -1.
        weight: Scale multiplied into the normalized values (it is cast to
            float32 first and must broadcast against ``hidden_states``).
        variance_epsilon: Added to the variance before the reciprocal
            square root.

    Returns:
        The normalized, scaled tensor in the original dtype of
        ``hidden_states``.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    # Centre once and reuse the result for both the variance and the output.
    centered = x - x.mean(-1, keepdim=True)
    var = centered.pow(2).mean(-1, keepdim=True)
    normed = centered * torch.rsqrt(var + variance_epsilon)
    scaled = weight.to(torch.float32) * normed
    return scaled.to(orig_dtype)


68
69
class LayerNorm(nn.Module):
    """Bias-free layer normalization with a learnable scale.

    The module owns only the ``weight`` parameter and the epsilon; the
    normalization itself is delegated to ``layer_norm_func``.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        """Create the scale parameter, initialized to ones.

        Args:
            param_shape: Shape handed to ``torch.ones`` to build ``weight``.
                NOTE(review): the ``None`` default does not look like a
                usable shape; callers in this file always pass one.
            eps: Value added to the variance inside ``layer_norm_func``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        # Tag the parameter with a ``weight_loader`` attribute so checkpoint
        # loading uses it instead of the default loader.
        # presumably this handles tensor-parallel sharding of the weight —
        # confirm in weight_utils.
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states``.

        Args:
            hidden_states: Tensor to normalize over its last dimension.
            residuals: Passed through unchanged; it is not used in the
                computation.

        Returns:
            A ``(normalized_hidden_states, residuals)`` tuple.
        """
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        return hidden_states, residuals


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block.

    A merged gate/up projection feeds ``SiluAndMul``, whose output goes
    through a down projection back to ``hidden_size``. No projection has a
    bias.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections.

        Args:
            config: Model configuration; ``hidden_size`` and
                ``intermediate_size`` are read from it.
            quant_config: Optional quantization settings forwarded to both
                linear layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up projections are fused into one layer with two output
        # partitions of ``intermediate_size`` each.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``.

        The linear layers return a pair; only the first element (the output
        tensor) is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class CohereAttention(nn.Module):
    """Self-attention block with rotary embeddings and optional QK norm.

    Head counts are divided by the tensor-parallel world size, so the
    ``num_heads`` / ``num_kv_heads`` attributes are per-rank values while the
    ``total_*`` attributes hold the counts from the config.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build projections, rotary embedding, attention and optional norms.

        Args:
            config: Model configuration (head counts, hidden size, rope
                settings, ``use_qk_norm``, ``layer_norm_eps`` ...).
            cache_config: Optional KV-cache settings forwarded to
                ``Attention``.
            quant_config: Optional quantization settings forwarded to the
                linear layers and to ``Attention``.

        Raises:
            AssertionError: If the KV head count and the tensor-parallel
                size do not divide one another as required below.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head, even when replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Prefer ``model_max_length`` when it is set and truthy; otherwise
        # fall back to ``max_position_embeddings`` (8192 if absent).
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        if self.use_qk_norm:
            # The norm weights have one row per local head.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        """Layer-normalize queries and keys per head.

        The last dimension is split into ``(heads, head_dim)``, each head is
        normalized, and the result is flattened back to the incoming layout.
        """
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Position input for the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention backend.
            attn_metadata: Metadata handed to the attention backend.

        Returns:
            The output of the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection lays out q, then k, then v on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class CohereDecoderLayer(nn.Module):
    """One decoder layer with attention and MLP applied in parallel.

    Both sub-blocks consume the same normalized input, and their outputs are
    summed with the un-normalized layer input.
    """

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the attention, MLP and input layer norm.

        Args:
            config: Model configuration.
            cache_config: Optional KV-cache settings for the attention block.
            quant_config: Optional quantization settings for both sub-blocks.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config)

        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: Position input forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: Cache tensor for this layer.
            attn_metadata: Metadata forwarded to the attention block.
            residual: Accepted for interface compatibility; its incoming
                value is overwritten with ``hidden_states`` before any use.

        Returns:
            ``(output, residual)`` where ``residual`` is the layer input
            (before normalization).
        """
        # Self Attention
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP reads the normalized input, not the attention output.
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        return hidden_states, residual


254
@support_torch_compile
class CohereModel(nn.Module):
    """Token embedding, the decoder layers owned by this rank, and final norm.

    ``make_layers`` returns the ``[start_layer, end_layer)`` range this
    instance iterates over in ``forward``.
    """

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ):
        """Build the embedding, decoder layers and final norm.

        Args:
            config: Model configuration.
            cache_config: Optional KV-cache settings passed to each layer.
            quant_config: Optional quantization settings passed to each layer.
            lora_config: Optional LoRA settings; only used here to compute
                the extended ``vocab_size``.
            prefix: Name prefix; ``"{prefix}.layers"`` is handed to
                ``make_layers``.
        """
        super().__init__()
        self.config = config
        # Extra vocabulary slots reserved for LoRA adapters, if any.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # The per-layer prefix supplied by make_layers is not used by
        # CohereDecoderLayer, so the factory ignores it.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CohereDecoderLayer(config, cache_config,
                                              quant_config),
            prefix=f"{prefix}.layers")
        self.norm = LayerNorm(param_shape=(config.hidden_size),
                              eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            positions: Position input forwarded to every layer.
            kv_caches: Per-layer caches, indexed relative to
                ``start_layer``.
            attn_metadata: Metadata forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                must not be ``None`` on ranks other than the first.

        Returns:
            The normalized hidden states on the last pipeline rank;
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


317
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Cohere causal language model: decoder stack, logits and sampling.

    The output head reuses the input embedding (``tie_word_embeddings`` is
    asserted in ``__init__``), so there is no separate ``lm_head`` module.
    """

    # Checkpoint sub-module names that are fused into one module here.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the model, logits processor and sampler.

        Args:
            config: Model configuration; ``tie_word_embeddings`` must be
                truthy.
            cache_config: Optional KV-cache settings.
            quant_config: Optional quantization settings.
            lora_config: Optional LoRA settings; extends the unpadded
                vocabulary size by ``lora_extra_vocab_size``.

        Raises:
            AssertionError: If ``config.tie_word_embeddings`` is falsy.
        """
        super().__init__()
        self.config = config
        # currently all existing command R models have `tie_word_embeddings`
        # enabled
        assert config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``CohereModel.forward`` and return its result as is."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits using the input embedding as the output head.

        When ``embed_tokens`` has no ``weight`` attribute (treated here as
        the LoRA-wrapped case) its ``base_layer`` is used instead.
        """
        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
        if is_not_lora:
            logits = self.logits_processor(self.model.embed_tokens,
                                           hidden_states, sampling_metadata)
        else:
            logits = self.logits_processor(self.model.embed_tokens.base_layer,
                                           hidden_states, sampling_metadata)

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint names containing ``q_proj`` / ``k_proj`` / ``v_proj`` or
        ``gate_proj`` / ``up_proj`` are redirected to the fused ``qkv_proj``
        / ``gate_up_proj`` parameters with the matching shard id. All other
        names are loaded one-to-one.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly remapped) names of the parameters that were
            loaded. Previously the method returned ``None``, so callers that
            ignore the result are unaffected.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no fused-shard entry loaded the tensor.
                # The ``continue`` statements below skip to the next
                # checkpoint entry without recording the name.
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params