commandr.py 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
23
from typing import Iterable, List, Optional, Tuple
24
25
26
27

import torch
import torch.utils.checkpoint
from torch import nn
28
from torch.nn.parameter import Parameter
29
30
31
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
32
33
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
36
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
40
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
41
42
43
44
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
45
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.model_executor.utils import set_weight_attrs
48
49
50
from vllm.sequence import SamplerOutput


51
52
53
54
55
56
57
58
59
60
61
62
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Apply a bias-free layer normalization over the last dimension.

    The statistics and scaling are computed in float32 and the result is
    cast back to the dtype of ``hidden_states``.

    Args:
        hidden_states: Tensor to normalize along its last dimension.
        weight: Scale multiplied into the normalized values (broadcast
            against ``hidden_states``).
        variance_epsilon: Added to the variance before the inverse sqrt.

    Returns:
        The normalized, scaled tensor in the original dtype.
    """
    original_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    centered = x - x.mean(-1, keepdim=True)
    variance = centered.pow(2).mean(-1, keepdim=True)
    normalized = centered * torch.rsqrt(variance + variance_epsilon)
    scaled = weight.to(torch.float32) * normalized
    return scaled.to(original_dtype)


63
64
class LayerNorm(nn.Module):
    """Bias-free LayerNorm module backed by :func:`layer_norm_func`.

    Args:
        param_shape: Shape of the learnable ``weight``. It is an int for the
            hidden-size norms and ``(num_heads, head_dim)`` for the per-head
            q/k norms built by the attention layer.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

    def forward(self, hidden_states, residuals=None):
        """Normalize ``hidden_states``; ``residuals`` is passed through as-is.

        Returns:
            A ``(normalized_hidden_states, residuals)`` tuple.
        """
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        return hidden_states, residuals

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's portion of a checkpoint weight into ``param``.

        1-D weights are copied whole (replicated on every rank).
        Multi-dimensional weights are sharded along dim 0 with
        ``param.shape[0]`` rows per rank. If the checkpoint holds fewer
        shards than there are tensor-parallel ranks (e.g. KV heads
        replicated across ranks), consecutive ranks share a shard instead of
        indexing past the end of ``loaded_weight``.
        NOTE(review): this assumes replicated KV heads are assigned as
        ``tp_rank // (tp_size // num_kv_heads)`` by the QKV projection —
        confirm against QKVParallelLinear.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        if param.dim() != 1:
            shard_size = param_data.shape[0]
            num_shards = loaded_weight.shape[0] // shard_size
            tp_size = get_tensor_model_parallel_world_size()
            # When every rank owns a distinct shard, num_shards == tp_size and
            # this is simply tp_rank; otherwise tp_size // num_shards
            # consecutive ranks map onto the same shard.
            shard_idx = tp_rank * num_shards // tp_size
            loaded_weight = loaded_weight.narrow(0, shard_idx * shard_size,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

88
89
90
91
92
93
94

# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated (SiLU-and-multiply) feed-forward block.

    A merged gate/up projection produces both halves in a single column-
    parallel matmul. ``SiluAndMul`` combines them, and a row-parallel down
    projection maps back to ``hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up projections fused into one output of width
        # 2 * intermediate_size.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class CohereAttention(nn.Module):
    """Tensor-parallel self-attention with rotary embeddings.

    Query heads are divided evenly across tensor-parallel ranks. KV heads
    are partitioned when there are at least as many as ranks and replicated
    otherwise. When ``config.use_qk_norm`` is set, queries and keys get a
    per-head LayerNorm before the rotary embedding is applied.
    """

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                "num_key_value_heads must be divisible by the TP size")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                "TP size must be divisible by num_key_value_heads")
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # "model_max_length" wins when present and truthy; otherwise use
        # "max_position_embeddings", defaulting to 8192.
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )
        if self.use_qk_norm:
            # One norm weight row per local head.
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        """Normalize q and k per head.

        The last dimension is split into ``(-1, head_dim)`` for the per-head
        norm and then flattened back.
        """
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Steps: fused QKV projection, split into q/k/v, optional q/k norm,
        rotary embedding, attention with ``kv_cache``, output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class CohereDecoderLayer(nn.Module):
    """One Cohere transformer block with a parallel residual.

    A single input LayerNorm feeds both the attention and the MLP. Their
    outputs are added together with the layer input.
    """

    def __init__(self,
                 config: CohereConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = CohereAttention(config, quant_config=quant_config)

        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        The incoming ``residual`` argument is not used: it is overwritten
        with ``hidden_states`` before normalization.

        Returns:
            ``(layer_output, layer_input)``, where ``layer_output`` is
            ``layer_input + attention(norm(x)) + mlp(norm(x))``.
        """
        # Self Attention
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP reads the same normalized input as attention (parallel
        # block) rather than the attention output.
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        return hidden_states, residual


class CohereModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final LayerNorm."""

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            CohereDecoderLayer(config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = LayerNorm(param_shape=(config.hidden_size),
                              eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and return normed states.

        ``kv_caches[i]`` is handed to layer ``i``; the list must have at
        least one entry per layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class CohereForCausalLM(nn.Module):
    """Cohere causal LM: backbone, tied-embedding logits, and sampler."""

    def __init__(
        self,
        config: CohereConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        # config.logit_scale is forwarded as the processor's `scale`.
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config, quant_config)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (no logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits using the input embedding matrix.

        The LM head is tied to ``embed_tokens``.
        """
        logits = self.logits_processor(self.model.embed_tokens.weight,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the vLLM ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with their shard id.
        Everything else goes through the parameter's own ``weight_loader``
        (or ``default_weight_loader``). ``lm_head.weight`` is skipped because
        logits use the tied embedding, and extra ``.bias`` tensors without a
        matching parameter (GPTQ) are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)