llama.py 11.1 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
9
10
11
12
13
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig

from cacheflow.models import InputMetadata
14
from cacheflow.models.attention import LlamaCacheFlowAttention
15
from cacheflow.models.layernorm import RMSNorm
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
18
19
20
21
22
23
24
25
26
27
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

# One decoder layer's KV cache: a pair unpacked by LlamaAttention.forward as
# ``(k_cache, v_cache)`` and passed straight to the attention kernel wrapper.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """LLaMA feed-forward block computing ``down_proj(silu(gate) * up)``.

    The gate and up projections are fused into one column-parallel matmul
    (``gate_up_proj``). On each tensor-parallel rank, the local output holds
    this rank's gate shard followed by its up shard along the last dimension;
    ``LlamaForCausalLM.load_weights`` builds the fused weight in that order.
    ``down_proj`` is row-parallel and consumes the already-partitioned
    activation (``input_is_parallel=True``).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        """Build the fused gate/up and down projections.

        Args:
            hidden_size: Model (input/output) feature size.
            intermediate_size: Total (unsharded) MLP inner size.
            hidden_act: Activation name; only ``'silu'`` is supported.

        Raises:
            ValueError: if ``hidden_act`` is not ``'silu'``.
        """
        super().__init__()
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and an unsupported activation would go unnoticed.
        if hidden_act != 'silu':
            raise ValueError(f"Unsupported activation: {hidden_act!r}. "
                             "Only 'silu' is supported for now.")
        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
                                                 bias=False, gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected output."""
        gate_up, _ = self.gate_up_proj(x)
        # Local output is [gate_shard | up_shard] along the last dimension;
        # chunk() yields the same two halves the reshape/split used to.
        gate, up = gate_up.chunk(2, dim=-1)
        x = self.act_fn(gate) * up
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention with a fused, tensor-parallel QKV projection.

    ``qkv_proj`` is column-parallel. On each rank its local output holds this
    rank's query, key and value shards concatenated in that order along the
    last dimension (the order ``LlamaForCausalLM.load_weights`` uses). The
    attention itself is delegated to ``LlamaCacheFlowAttention``, which also
    receives the token positions and this layer's KV cache. ``o_proj`` is
    row-parallel.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        """Build the projections for ``num_heads`` heads split across ranks.

        Raises:
            ValueError: if ``num_heads`` is not divisible by the tensor-parallel
                world size, or ``hidden_size`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Explicit checks instead of ``assert`` (stripped under ``python -O``).
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({self.total_num_heads}) must be divisible by the "
                f"tensor parallel world size ({tensor_model_parallel_world_size}).")
        if hidden_size % self.total_num_heads != 0:
            # Otherwise head_dim is silently truncated and the projection
            # widths no longer match the checkpoint.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({self.total_num_heads}).")
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to Q/K/V, run cached attention, and project back.

        ``hidden_states`` presumably has tokens flattened into the leading
        dimension (the module is described as "1D"). TODO: confirm against
        the caller.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Local output is [q_shard | k_shard | v_shard] along the last dim.
        # The attention op gets contiguous copies, as before.
        q, k, v = (part.contiguous() for part in qkv.chunk(3, dim=-1))
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """A single pre-norm LLaMA block.

    Computes ``x + attn(norm1(x))`` followed by ``y + mlp(norm2(y))``, where
    both norms are RMSNorm layers configured from ``config``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        size = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = size
        self.self_attn = LlamaAttention(
            hidden_size=size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(size, eps=eps)
        self.post_attention_layernorm = RMSNorm(size, eps=eps)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each wrapped in a residual add."""
        # Attention sub-block: the residual is the un-normalized input.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-block on the post-attention stream.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class LlamaModel(nn.Module):
    """LLaMA trunk: token embedding, decoder layers, and a final RMSNorm.

    Returns normalized hidden states; it does not compute logits.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                   perform_initialization=False)
        num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config) for _ in range(num_layers))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Layer ``i`` receives ``kv_caches[i]`` and, when ``cache_events`` is
        not None, ``cache_events[i]``; otherwise it receives ``None``.
        """
        hidden_states = self.embed_tokens(input_ids)
        for idx, layer in enumerate(self.layers):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
            )
        return self.norm(hidden_states)


class LlamaForCausalLM(nn.Module):
    """LLaMA trunk plus a tensor-parallel LM head and a token sampler.

    ``forward`` returns sampled next tokens rather than logits. The final
    hidden states are passed to ``Sampler`` together with the ``lm_head``
    weight, which is sharded along the vocabulary dimension.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the trunk and sample next tokens via ``self.sampler``."""
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

    # Name substrings of parameters sharded along dim 0 (rows) across ranks.
    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
                                "qkv_proj.weight", "gate_proj.weight",
                                "up_proj.weight"]
    # Name substrings of parameters sharded along dim 1 (columns) across ranks.
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, weights_path: str):
        """Copy this rank's shard of every parameter from ``weights_path``.

        ``weights_path`` is a directory with one ``np.save``-format file per
        HuggingFace tensor, named by its state-dict key (the layout that
        ``get_weights`` produces).

        Fused parameters have no file of their own. For ``qkv_proj`` and
        ``gate_up_proj``, this rank's rows of ``q_proj``/``k_proj``/``v_proj``
        (respectively ``gate_proj``/``up_proj``) are concatenated along dim 0,
        in the order the attention/MLP ``forward`` splits them.

        Raises:
            ValueError: if a loaded tensor's shape differs from the parameter's.
        """
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()

        def rank_shard(shard_size: int) -> slice:
            # Indices [rank * shard_size, (rank + 1) * shard_size).
            start = shard_size * tensor_model_parallel_rank
            return slice(start, start + shard_size)

        state_dict = self.state_dict()
        for name, param in state_dict.items():
            if "qkv_proj" in name or "gate_up_proj" in name:
                if "qkv_proj" in name:
                    original_name = "qkv_proj"
                    weight_names = ["q_proj", "k_proj", "v_proj"]
                else:
                    original_name = "gate_up_proj"
                    weight_names = ["gate_proj", "up_proj"]
                # The local fused param stacks one equal shard per piece.
                shard_size = param.shape[0] // len(weight_names)
                weights_to_concat = []
                for weight_name in weight_names:
                    weight = np.load(os.path.join(
                        weights_path, name.replace(original_name, weight_name)))
                    weights_to_concat.append(weight[rank_shard(shard_size)])
                loaded_weight = torch.from_numpy(
                    np.concatenate(weights_to_concat, axis=0))
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                if any(p in name for p in self._column_parallel_weights):
                    loaded_weight = loaded_weight[rank_shard(param.shape[0])]
                if any(p in name for p in self._row_parallel_weights):
                    loaded_weight = loaded_weight[:, rank_shard(param.shape[1])]

            # Explicit check instead of ``assert`` (stripped under ``-O``),
            # with enough context to identify the offending tensor.
            if param.shape != loaded_weight.shape:
                raise ValueError(
                    f"Shape mismatch for {name!r}: parameter has shape "
                    f"{tuple(param.shape)}, loaded weight has shape "
                    f"{tuple(loaded_weight.shape)}.")
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        """Convert a local HuggingFace checkpoint into per-tensor numpy files.

        Args:
            model_name: Path to a local HuggingFace model directory. It must
                contain ``config.json`` and ``*.bin`` checkpoint files.
            path: Unused; kept for interface compatibility. Output always goes
                to ``<model_name>/np``.

        Returns:
            Absolute path of the output directory. It holds one ``np.save``
            file per tensor, named by the tensor's state-dict key.

        Raises:
            ValueError: if ``model_name`` has no ``config.json``, or if a
                conversion is needed but no ``*.bin`` files exist.
        """
        if not os.path.isfile(os.path.join(model_name, "config.json")):
            raise ValueError("LLaMA model's model_name has to be a path "
                             "to the huggingface model's directory.")
        np_dir = os.path.abspath(
            os.path.expanduser(os.path.join(model_name, "np")))
        os.makedirs(np_dir, exist_ok=True)
        # Serialize conversion among processes sharing this directory so
        # only one of them writes the files.
        lock = filelock.FileLock(os.path.join(np_dir, "file_lock"))

        with lock:
            # The presence of this one tensor is taken as "already converted".
            # NOTE(review): an interrupted conversion can leave this file in
            # place while others are missing; load_weights would then fail
            # with FileNotFoundError.
            test_weight_path = os.path.join(np_dir, "model.embed_tokens.weight")
            if os.path.exists(test_weight_path):
                return np_dir

            bin_files = glob.glob(os.path.join(model_name, "*.bin"))
            if not bin_files:
                # Fail here rather than return an empty directory.
                raise ValueError(
                    f"No *.bin checkpoint files found in {model_name!r}.")

            for bin_file in tqdm(bin_files, desc="Convert format"):
                # torch.load unpickles the file: only use trusted checkpoints.
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(np_dir, name)
                    with open(param_path, "wb") as f:
                        # NOTE(review): Tensor.numpy() raises for bfloat16;
                        # fp16/fp32 checkpoints are assumed.
                        np.save(f, param.cpu().detach().numpy())
                # Release this shard before loading the next one, so two
                # shards are never held at the same time.
                del state

            return np_dir