pytorch_new_model.md 15.5 KB
Newer Older
zhouxiang's avatar
zhouxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# lmdeploy.pytorch 新模型支持

lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,新模型的支持依赖于 patch 机制,对原模型做修改以及功能添加,以期可以最大程度上复用模型的原始实现,减少工作量。

## 模型支持

我们以 transformers 中的 llama 实现来介绍模型支持的流程

在开始之前,我们首先要了解一下模型的输入。lmdeploy.pytorch 的输入与标准 transformers 模型的输入略有不同,差异主要体现在如下方面:

1. 由于支持了 continuous batching,一个 batch 的输入 `input_ids` 会被拼接成一维的长序列,然后 `unsqueeze(0)` 来保证输入维度与 transformers 中相同。这样的输入不会影响 MLP 以及 RMSNorm 等模块的计算。
2. 由于添加了对 paged attention 的支持,`past_key_value` 不再是原来的大小,而是一组形状为 `[num_blocks, block_size, num_heads, head_dim]` 的 cache 块,num_blocks 为总 block 数量,由可用显存大小决定,block_size 为预设的块大小。这样的输入改变会影响到 LlamaModel 和 LlamaAttention 的计算,因此要对这两个模块的实现进行修改。
3. 由于上述输入的改变,模型中需要一些额外的输入来支持推理,比如 batch 中的序列起始位置和长度,kv cache 的 block table 等。这些输入并不在模块的 forward 参数列表中,我们需要维护一个上下文以获得这些输入。

上面的输入改动会影响 LlamaModel 和 LlamaAttention,首先我们来实现新的 LlamaModel,这是对原始实现的简化,我们删除了很多检查代码,以避免由于输入改变造成的断言失败,仅保留了最小程度的代码:

```python
# lmdeploy/pytorch/models/llama.py

class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of LlamaModel.forward."""
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
```

然后是对 LlamaAttention 模块的改写。按顺序实现如下操作:

1. kqv proj
2. rotary embedding
3. 填充 kv cache
4. MHA 计算
5. o proj

continuous batching 和 kv cache 的改动对该模块的影响比较大

```python
# lmdeploy/pytorch/models/llama.py
from lmdeploy.pytorch.kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd

class LlamaAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of LlamaAttention.forward."""
        context = self.context.context
        history_lengths = context.history_lengths
        position_ids_1d = context.position_ids_1d
        block_offsets = context.block_offsets

        # qkv proj
        query_states = q_proj(hidden_states)
        key_states = k_proj(hidden_states)
        value_states = v_proj(hidden_states)
        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        # rotary embedding
        max_seq_len = position_ids.size(-1)
        kv_seq_len = max_seq_len + max(history_lengths)
        if kv_seq_len >= self.rotary_emb.max_seq_len_cached:
            cos, sin = self.rotary_emb(value_states,
                                        seq_len=kv_seq_len + 128)
        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            self.rotary_emb.cos_cached,
            self.rotary_emb.sin_cached,
            position_ids,
            position_ids_1d,
            q_embed=query_states,
            k_embed=key_states)

        # fill kv cache
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        fill_kv_cache(key_states,
                      value_states,
                      past_key_value[0],
                      past_key_value[1],
                      q_start_loc,
                      q_seq_length,
                      block_offsets=block_offsets,
                      history_lengths=history_lengths,
                      context=context)

        # attention
        attn_output = query_states
        block_size = past_key_value[0].size(1)
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_seq_len,
        )
        hidden_size = num_heads * head_dim
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size)

        # o proj
        attn_output = o_proj(attn_output)
        return attn_output, None, past_key_value
```

上面的代码有几处值得注意的地方,首先是 context 对象。我们需要 history_lengths、block_offsets 等参数辅助运算,这些参数无法通过模型的 forward 函数传递进来。因此我们维护了一个 context 对象,把几乎所有可能用到的输入参数都保存在其中,方便在各个模块间共享。context 对象可以通过 `self.context.context` 来访问,结构可以参考 [context-结构](#context-结构)

另一个值得注意的地方就是自定义 kernel,由于输入形式的改变,原来的 LlamaAttention 实现变得不再适用,为了保证推理的速度和正确性,我们在 lmdeploy.pytorch.kernels 中实现了许多自定义的 triton kernel,上面的模块中就用到了 `apply_rotary_pos_emb``fill_kv_cache``paged_attention_fwd` ,分别负责实现 rotary embedding,填充 kv cache 还有 attention 的计算。

有了上述的两个模块后,还需要将他们注册到 `lmdeploy/pytorch/models/module_map.py` 中,进行原模块与 patch 模块的映射

```python
# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
    'transformers.models.llama.modeling_llama.LlamaAttention':
    'lmdeploy.pytorch.models.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaModel':
    'lmdeploy.pytorch.models.llama.LlamaModel'
})
```

完成注册后,Engine 在启动时就会将这两个模块 patch 成新的实现,完成后续的部署任务。

## Tensor 并发支持

为了支持 Tensor 并发,需要对模型的权重做切分。让我们试着为上面接入的 Llama 模型添加 TP 的支持。

Llama 中涉及到 Tensor 并发的模块是 LlamaAttention 中的 qkvo proj 和 LlamaMLP 中的 gate,up 和 down proj。其中 o_proj 和 down_proj 需要按行切分,剩下的按列切分。我们可以在对应的模块中实现 `_distribution_partition_fn` 函数:

```python
# lmdeploy/pytorch/models/llama.py
from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)

class LlamaAttention(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

class LlamaMLP(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['gate_proj', 'up_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['down_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

```

`_distribute_partition_fn` 会在加载模型权重时被调用,对应的权重会被按照特定的形式分配到对应的设备中。

按照目前的方案切分后的权重,需要对 o_proj 和 down_proj 的结果进行 all_reduce 操作才能得到正确的结果。可以选择将 all_reduce 放在模型的 forward 函数中,也可以选择另一种方案,添加 `_distribute_output_fn` 函数:

```python
# lmdeploy/pytorch/models/llama.py
import torch.distributed as dist

class LlamaAttention(nn.Module):
    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

class LlamaMLP(nn.Module):
    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs
```

最后别忘了将 LlamaMLP 也注册进 module_map 中

```python
# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
    'transformers.models.llama.modeling_llama.LlamaMLP':
    'lmdeploy.pytorch.models.llama.LlamaMLP'
})
```

这样就可以利用多卡的优势,让更大的模型部署成为可能

## 模块调试

当模型的输出不符合预期时,我们会希望调试某个特定模块以确定添加的重写是否正确。`lmdeploy.pytorch` 提供了一些工具以帮助进行精度对齐。还是以上面提到的 `LlamaAttention` 模块为例。

首先,我们通过 transformers 的 API 得到想要调试的子模块的一个实例:

```python
import torch
from transformers import AutoModelForCausalLM

# get module
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.float16).cuda()
self_attn = model.model.layers[0].self_attn
```

然后,使用 `ModuleIOExtractor` 工具可以生成该模块的一组输入输出

```python
from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor

# extract module input/output
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()
extractor = ModuleIOExtractor(model, self_attn)
attn_args, attn_kwargs, attn_output = extractor.extract(input_ids)
```

重写模块的输入与原模块略有不同,主要体现在三方面:

1. 模型需要一些特殊输入输出,他们以 `StepContext` 的形式传入,可以使用 `make_step_context` 生成。
2. `input_ids``hidden_states` 等数据都被 continuous 化,可以使用 `continuous_tensor` 进行处理。
3. 由于 paged caching 的需要, `past_key_value` 需要被 page 化处理。

基于以上原因,我们要对提取的输入进行加工:

```python
from lmdeploy.pytorch.tools.make_inputs import make_step_context
from lmdeploy.pytorch.tools.layout_convert import continuous_tensor

# create patched input/output
context = make_step_context(input_ids,
                            kv_cache_dtype=dtype,
                            num_key_value_heads=32)
seq_length = context.q_seq_length
attn_kwargs['hidden_states'] = continuous_tensor(
    attn_kwargs['hidden_states'],
    seq_length)
attn_kwargs['past_key_value'] = context.kv_caches[0]
```

然后就可以启动重写,并比较结果正确性了。(注意输出也要 continuous 化后进行比较)

```python
from lmdeploy.pytorch.models import patch

# patch and test
patched_self_attn = patch(self_attn, extra_args=['context'])
with torch.inference_mode():
    patched_output = patched_self_attn.patched_forward(*attn_args,
                                                       **attn_kwargs,
                                                       context=context)
torch.testing.assert_close(patched_output[0],
                            continuous_tensor(attn_output[0], seq_length))
```

可以通过上述方法调试重写模块,直到精度满足预期。

## 附录

### context 结构

```python
@dataclass
class StepContext:
    """context of Model.
    """
    inputs: ModelInputs
    block_offsets: torch.LongTensor
    position_ids: torch.LongTensor
    position_ids_1d: torch.LongTensor
    q_start_loc: torch.LongTensor
    history_lengths: torch.LongTensor
    seq_length: torch.LongTensor
    max_seq_length: int
    kv_seq_length: torch.LongTensor
    kv_caches: List
    is_decoding: bool
    world_size: int = 1
    json_config: Dict = None
    local_adapter_ids: torch.LongTensor = None
    global_adapter_ids: torch.LongTensor = None
    adapter_offsets: torch.LongTensor = None
    max_rank: int = 0
```

### FAQ

- **如何访问 patch 前的模块?**

有时我们只希望在函数前后加一个 hook 代码,不希望大段的拷贝函数,可以通过 `self.origin_mod` 访问 patch 前的模块。

- **非 transformers 官方的模型该如何注册?**

一些模型的实现代码可能是以 remote code 的形式添加的,这样的模块无法通过完整的 qualname 来定位。lmdeploy.pytorch 支持使用缩写的模块名进行注册:

```python
MODULE_MAP.update({
    'modeling_internlm.InternLMAttention':
    'lmdeploy.pytorch.models.internlm.PatchedInternLMAttention',
})
```

> \[!NOTE\]
>
> 缩写的优先级会更低,有条件的话还是鼓励使用完整的 qualname 进行注册。

- **模块出现同名但不同实现怎么处理?**

目前推荐的做法是同名就映射到同一个实现中,然后在实现内部根据模块的固有参数来判断模型该使用的类型,以 baichuan2 7b/13b 为例:

```python
class BaichuanModel(nn.Module):
    def forward(self, ...):
        if self.config.num_hidden_layers == 32:
            return forward_7b(...)
        else:
            return forward_default(...)
```

- **如果希望在推理前对模块进行初始化?**

可以实现模块的 `_update_model_fn` 函数,它会在模块的权重都加载完,完成 TP 权重切分后被调用

```python
class LlamaAttention:
    def _update_model_fn(self):
        # ADD YOUR CODE HERE
```