# Copyright (c) OpenMMLab. All rights reserved.
import configparser
import json
import os
import os.path as osp
import re
import shutil
from pathlib import Path

import fire
import safetensors
import torch
from sentencepiece import SentencePieceProcessor

supported_formats = ['llama', 'hf']
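# 'llama' refers to META's original checkpoint layout (params.json plus
# *.pth/*.pt shards); 'hf' refers to the huggingface transformers layout
# (config.json plus *.bin shards).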


def create_workspace(_path: str):
    """Create a workspace.

    Args:
        _path (str): the path of the workspace
    Returns:
        bool: success or not
    """
    try:
        if osp.exists(_path):
            shutil.rmtree(_path)
        os.makedirs(_path)
        print(f'create workspace in directory {_path}')
        return True
    except Exception as e:
        print(f'create workspace in {_path} failed: {e}')
        return False


def destroy_workspace(_path: str):
    """destroy workspace.

    Args:
        _path(str): the path of the workspace
    Returns:
        bool: success or not
    """
    try:
        shutil.rmtree(_path)
        print(f'destroy workspace in directory {_path}')
        return True
    except Exception as e:
        print(f'destroy workspace in {_path} failed: {e}')
        return False


def copy_triton_model_templates(_path: str):
    """copy triton model templates to the specified path.

    Args:
        _path (str): the target path
    Returns:
        str: the path of the triton models
    """
    try:
        cur_path = osp.abspath(__file__)
        dir_path = osp.dirname(cur_path)
        triton_models_path = osp.join(dir_path, 'triton_models')
        dst_path = osp.join(_path, 'triton_models')
        shutil.copytree(triton_models_path, dst_path, symlinks=True)
        print(f'copy triton model templates from "{triton_models_path}" to '
              f'"{dst_path}" successfully')
        shutil.copy(osp.join(dir_path, 'service_docker_up.sh'), _path)
        return dst_path
    except Exception as e:
        print(f'copy triton model templates from "{triton_models_path}"'
              f' to "{dst_path}" failed: {e}')
        return None


def tokenizer_info(model_path: str):
    """Return the vocabulary size, bos token id and eos token id.

    Args:
        model_path (str): the tokenizer model's path
    Returns:
        tuple: vocabulary size, bos token id and eos token id
    """
    assert os.path.isfile(model_path), model_path
    sp_model = SentencePieceProcessor(model_file=model_path)
    # BOS / EOS token IDs
    n_words = sp_model.vocab_size()
    bos_id = sp_model.bos_id()
    eos_id = sp_model.eos_id()
    return n_words, bos_id, eos_id


def export(model_name: str,
           num_layer: int,
           norm_eps: float,
           model_params: dict,
           tokenizer_path: str,
           out_dir: str,
           tp: int,
           size_per_head: int = 128):
    """Export deploying information to a config file.

    Args:
        model_name (str): model's name
        num_layer (int): the number of transformer blocks
        norm_eps (float): norm epsilon
        model_params (dict): parameters of a model
        tokenizer_path (str): the tokenizer model's path
        out_dir (str): the path of the output directory
        tp (int): the number of GPUs used for tensor parallelism
        size_per_head (int): the dimension of each head
    """
    out_dir = osp.join(out_dir, 'weights')
    os.makedirs(out_dir, exist_ok=True)

    def save_bin(param: torch.Tensor, name):
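        """Dump a tensor to `{out_dir}/{name}` as a raw binary file.

        fp32/bf16 tensors are cast to fp16 first; other dtypes are written
        as-is.
        """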
        print(name, param.shape)
        if param.dtype in [torch.float, torch.bfloat16]:
            param = param.half()
        param.contiguous().numpy().tofile(osp.join(out_dir, name))

    attn_bias = False
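    # flipped to True below when a `w_qkv.bias` tensor is encountered; the
    # flag is exported as `attn_bias` in config.ini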

    # the weights were transposed to column-major by the caller, so the
    # splitting axes are reversed accordingly
    for param_name, param_data in model_params.items():
        if param_name == 'tok_embeddings.weight':
            _vocab_size, dim = param_data.shape
            head_num = dim // size_per_head
        split_dim = None
        key, ext = param_name.split('.')[-2:]
        if key == 'w_qkv' and ext == 'bias':
            attn_bias = True
        copy = False
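        # w1/w3/w_qkv are split along the last dim (column-parallel after the
        # transpose), w2/wo along dim 0 (row-parallel); per-channel
        # scales/zeros/bias of w2/wo are replicated on every rank instead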
        if key in ['w1', 'w3', 'w_qkv']:
            split_dim = -1
            if key == 'w1':
                inter_size = param_data.shape[-1]
        elif key in ['w2', 'wo']:
            if ext in ['scales', 'zeros', 'bias']:
                copy = True
            else:
                split_dim = 0
        if split_dim is not None:
            print(f'*** splitting {param_name}, shape={param_data.shape}, '
                  f'split_dim={split_dim}')
            assert param_data.shape[split_dim] % tp == 0
            split_size = param_data.shape[split_dim] // tp
            splits = torch.split(param_data, split_size, dim=split_dim)
            for i, split in enumerate(splits):
                prefix, ext = osp.splitext(param_name)
                save_bin(split, f'{prefix}.{i}{ext}')
        elif copy:
            print(f'### copying {param_name}, shape={param_data.shape}')
            copies = [param_data] * tp
            for i, copy_data in enumerate(copies):
                prefix, ext = osp.splitext(param_name)
                save_bin(copy_data, f'{prefix}.{i}{ext}')
        else:
            save_bin(param_data, param_name)

    # export config and save it to {out_dir}/config.ini
    vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
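    # the embedding table may be padded, so it only has to be at least as
    # large as the tokenizer vocabulary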
    assert _vocab_size >= vocab_size, \
        f'different vocab size {_vocab_size} vs {vocab_size}'
    cfg = dict(llama=dict(
        model_name=model_name,
        head_num=head_num,
        size_per_head=size_per_head,
        vocab_size=vocab_size,
        num_layer=num_layer,
        rotary_embedding=size_per_head,
        inter_size=inter_size,
        norm_eps=norm_eps,
        attn_bias=int(attn_bias),
        start_id=bos_id,
        end_id=eos_id,
        weight_type='fp16',
        # parameters for turbomind
        max_batch_size=32,
        max_context_token_num=4,
        session_len=2056,
        step_length=1,
        cache_max_entry_count=48,
        cache_chunk_size=1,
        use_context_fmha=1,
        quant_policy=0,
        tensor_para_size=tp))

    config = configparser.ConfigParser()
    for section, key_values in cfg.items():
        config[section] = key_values

    config_path = osp.join(out_dir, 'config.ini')
    with open(config_path, 'w') as f:
        config.write(f)
    return True


def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
                 triton_models_path: str, tp: int):
    """Deploy a model with huggingface transformers' format.

    Args:
        model_name (str): the name of the to-be-deployed model
        model_path (str): the path of the directory where the model weight
          files are
        tokenizer_path (str): the path of the tokenizer model
        triton_models_path (str): the path of the exported triton models
        tp (int): the number of GPUs used for tensor parallelism
    """
    if osp.exists(tokenizer_path):
        shutil.copy(tokenizer_path,
                    osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
    else:
        print(f'tokenizer model {tokenizer_path} does not exist')
        return False
    # read model arguments from params.json
    try:
        params_path = osp.join(model_path, 'params.json')
        with open(params_path) as f:
            model_arg = json.load(f)
            num_layer = model_arg['n_layers']
            norm_eps = model_arg['norm_eps']
    except Exception as e:
        print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}')
        return False

    # convert weights from llama to turbomind format
    checkpoints = []
    for pattern in ['*.pth', '*.pt']:
        checkpoints += sorted(Path(model_path).glob(pattern))
    print(checkpoints)
    n_ckpt = len(checkpoints)
    model_params = {}

    def get_param(_name, _size):
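        """Allocate (on first use) and return the merged fp16 tensor.

        Shards from each checkpoint are copied into slices of this tensor.
        """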
        print(_name, _size)
        if _name not in model_params:
            model_params[_name] = torch.zeros(_size,
                                              dtype=torch.float16,
                                              device='cpu')
        return model_params[_name]

    for i, ckpt_path in enumerate(checkpoints):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        for param_name, param_data in ckpt.items():
            key, ext = param_name.split('.')[-2:]
            # column-parallel
            if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']:
                size = param_data.size(0)
                if ext == 'weight':
                    param = get_param(
                        param_name,
                        [size * n_ckpt, param_data.size(1)])
                    param.data[size * i:size * (i + 1), :] = param_data
                else:  # bias
                    param = get_param(param_name, [size * n_ckpt])
                    param.data[size * i:size * (i + 1)] = param_data
            # row-parallel
            elif key in ['w2', 'wo', 'tok_embeddings']:
                size = param_data.size(-1)
                if ext == 'weight':
                    param = get_param(param_name,
                                      [param_data.size(0), size * n_ckpt])
                    param.data[:, size * i:size * (i + 1)] = param_data
                else:  # bias
                    param = get_param(param_name, [size])
                    param.data = param_data

            elif i == 0:
                param = get_param(param_name, param_data.size())
                param.data = param_data
        del ckpt

    for name, param in model_params.items():
        # transpose all weights as TurboMind is expecting column-major
        # weights: (output_dims, input_dims) -> (input_dims, output_dims)
        key = name.split('.')[-2]
        if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'w2', 'wo']:
            param.data = param.data.t()

    # concat qkv projection
    for t in ['weight', 'bias']:
        for i in range(1000):
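            # 1000 is just a generous upper bound; the loop breaks at the
            # first layer index whose q/k/v tensors are no longer present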
            _qkv = [
                f'layers.{i}.attention.{k}.{t}' for k in ['wq', 'wk', 'wv']
            ]
            try:
                qkv = tuple(map(model_params.pop, _qkv))
            except KeyError:
                break
            # concat by output_dims
            qkv = torch.stack(qkv, dim=qkv[0].dim() - 1)
            print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
            model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv

    assert i == 0 or num_layer == i, f'mismatched layers: {num_layer} vs {i}'

    return export(model_name, num_layer, norm_eps, model_params,
                  tokenizer_path, triton_models_path, tp)


def permute(x: torch.Tensor):
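    """Convert q/k tensors from huggingface's rotary-embedding layout back to
    the original llama ("fb") layout, one `SIZE_PER_HEAD`-dim head at a
    time."""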
    SIZE_PER_HEAD = 128
    if x.shape[-1] > 1:  # qweights
        dim = x.shape[-1]
        n_heads = dim // SIZE_PER_HEAD
        return x.view(-1, n_heads, 2,
                      dim // n_heads // 2).transpose(2, 3).reshape(-1, dim)
    else:  # scales, zeros
        dim = x.shape[0]
        n_heads = dim // SIZE_PER_HEAD
        return x.view(n_heads, 2, dim // n_heads // 2,
                      1).transpose(1, 2).reshape(dim, 1)


def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
              triton_models_path: str, tp: int):
    """Deploy a model with huggingface transformers' format.

    Args:
        model_name (str): the name of the to-be-deployed model
        model_path (str): the path of the directory where the model weight
          files are
        tokenizer_path (str): the path of the tokenizer model
        triton_models_path (str): the path of the exported triton models
        tp (int): the number of GPUs used for tensor parallelism
    """
    if tokenizer_path is None:
        tokenizer_path = osp.join(model_path, 'tokenizer.model')
    if osp.exists(tokenizer_path):
        shutil.copy(tokenizer_path,
                    osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
        for _file in os.listdir(model_path):
            if _file.endswith('.json') or _file.endswith('.py'):
                json_path = osp.join(model_path, _file)
                shutil.copy(json_path,
                            osp.join(triton_models_path, 'tokenizer', _file))
    else:
        print(f'tokenizer model {tokenizer_path} does not exist')
        exit(-1)

    # read model arguments from config.json
    try:
        params_path = osp.join(model_path, 'config.json')
        with open(params_path) as f:
            model_arg = json.load(f)
            num_layer = model_arg['num_hidden_layers']
            norm_eps = model_arg['rms_norm_eps']
    except Exception as e:
        print(f'get "num_hidden_layers" and "rms_norm_eps" from '
              f'{params_path} failed: {e}')
        return False

    # convert weights from hf to turbomind
    model_params = {}

    _qweight = 'weight'
    _suffixes = [_qweight, 'bias']
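    # every suffix is tried for every layer; tensors that a checkpoint does
    # not provide (e.g. biases in llama models) are simply skipped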

    _files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
    _files = sorted(_files)
    print(_files)

    _params = {}
    for _file in _files:
        _tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
        _params.update(_tmp)

    def get_tensor(name):
        """return tensor according its name."""
        return _params[name]

    def get_tensor_transposed(name: str):
        """return a transposed tensor according its name."""
        if name not in _params and 'bias' in name:
            return None
        return _params[name].t()

    w_pack = False
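    # some huggingface checkpoints (e.g. Baichuan) fuse q/k/v into a single
    # `W_pack` projection; detect that layout so it can be split below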
    if 'model.layers.0.self_attn.W_pack.weight' in _params:
        w_pack = True

    for i in range(1000):
        try:
            # attention weights
            for suffix in _suffixes:
                if w_pack:
                    _qkvo = [
                        f'model.layers.{i}.self_attn.{t}'
                        for t in ['W_pack', 'o_proj']
                    ]
                    qkv, o = map(get_tensor_transposed,
                                 map(('{}.' + suffix).format, _qkvo))

                    if qkv is None:
                        continue
                    _shape = qkv.shape[1] // 3
                    _qkv = torch.split(qkv, [_shape, _shape, _shape], dim=1)
                    q = _qkv[0]
                    k = _qkv[1]
                    v = _qkv[2]

                else:
                    _qkvo = [
                        f'model.layers.{i}.self_attn.{t}_proj' for t in 'qkvo'
                    ]
                    q, k, v, o = map(get_tensor_transposed,
                                     map(('{}.' + suffix).format, _qkvo))
                if q is None:
                    continue
                # q and k have different layouts in fb & hf; convert to fb's
                # layout
                q = permute(q)
                k = permute(k)
                if suffix == _qweight:  # weight, qweight
                    # insert a dimension for splitting heads later
                    qkv = torch.stack((q, k, v), dim=1)
                else:  # scales, zeros, bias
                    qkv = torch.stack((q.squeeze(), k.squeeze(), v.squeeze()),
                                      dim=0).squeeze(dim=-1)
                    print(suffix, qkv.shape)
                for k, v in [('w_qkv', qkv), ('wo', o)]:
                    model_params[f'layers.{i}.attention.{k}.{suffix}'] = v
            # ffn weights
            _w123 = [
                f'model.layers.{i}.mlp.{t}_proj'
                for t in ['gate', 'down', 'up']
            ]
            for suffix in _suffixes:
                w1, w2, w3 = map(get_tensor_transposed,
                                 map(('{}.' + suffix).format, _w123))
                if w1 is None:
                    continue
                if suffix in ['scales', 'zeros', 'bias']:
                    w1, w2, w3 = map(lambda x: x.squeeze(dim=-1), [w1, w2, w3])
                for k, v in [('w1', w1), ('w2', w2), ('w3', w3)]:
                    model_params[f'layers.{i}.feed_forward.{k}.{suffix}'] = v
            other = [('attention_norm.weight', 'input_layernorm.weight'),
                     ('ffn_norm.weight', 'post_attention_layernorm.weight')]
            for ft, hf in other:
                model_params[f'layers.{i}.' +
                             ft] = get_tensor(f'model.layers.{i}.' + hf)
        except safetensors.SafetensorError:
            break
        except KeyError:
            break

    assert num_layer == i, f'mismatched layers: {num_layer} vs {i}'

    other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
             ('norm.weight', 'model.norm.weight'),
             ('output.weight', 'lm_head.weight')]
    for ft, hf in other:
        model_params[ft] = get_tensor(hf)

    return export(model_name, num_layer, norm_eps, model_params,
                  tokenizer_path, triton_models_path, tp)


def pack_model_repository(workspace_path: str):
    """package the model repository.

    Args:
        workspace_path: the path of workspace
    """
    model_repo_dir = osp.join(workspace_path, 'model_repository')
    os.makedirs(model_repo_dir, exist_ok=True)
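    # the symlink sources are relative to `model_repository/`, hence the
    # leading `../`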
    os.symlink(src=osp.join('../triton_models/interactive'),
               dst=osp.join(model_repo_dir, 'turbomind'))
    os.symlink(src=osp.join('../triton_models/preprocessing'),
               dst=osp.join(model_repo_dir, 'preprocessing'))
    os.symlink(src=osp.join('../triton_models/postprocessing'),
               dst=osp.join(model_repo_dir, 'postprocessing'))


def main(model_name: str,
         model_path: str,
         model_format: str,
         tokenizer_path: str = None,
         dst_path: str = './workspace',
         tp: int = 1):
    """deploy llama family models via turbomind.

    Args:
        model_name (str): the name of the to-be-deployed model, such as
            llama-7b, llama-13b, vicuna-7b, etc.
        model_path (str): the directory path of the model
        model_format (str): the format of the model, llama or hf. 'llama'
            stands for META's llama format, and 'hf' means huggingface format
        tokenizer_path (str): the path of the tokenizer model
        dst_path (str): the destination path that saves outputs
        tp (int): the number of GPUs used for tensor parallelism
    """

    if model_format not in supported_formats:
        print(f'the model format "{model_format}" is not supported. '
              f'The supported formats are: {supported_formats}')
        exit(-1)

    if model_format == 'llama' and tokenizer_path is None:
        print('The model is in llama format. Its tokenizer model path must '
              'be specified')
        exit(-1)

    if not create_workspace(dst_path):
        exit(-1)

    triton_models_path = copy_triton_model_templates(dst_path)
    if triton_models_path is None:
        exit(-1)

    if model_format == 'llama':
        res = deploy_llama(model_name, model_path, tokenizer_path,
                           triton_models_path, tp)
    else:
        res = deploy_hf(model_name, model_path, tokenizer_path,
                        triton_models_path, tp)

    # update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
    with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
              'a') as f:
        param = 'parameters {\n  key: "tensor_para_size"\n  value: {\n    ' \
            'string_value: ' + f'"{tp}"\n' + '  }\n}\n'
        f.write(param)
    if not res:
        print(f'deploy model "{model_name}" via turbomind failed')
        destroy_workspace(dst_path)
        exit(-1)

    # pack model repository for triton inference server
    pack_model_repository(dst_path)

    # update the value of $TP in `service_docker_up.sh`
    file_path = osp.join(dst_path, 'service_docker_up.sh')
    with open(file_path, 'r') as f:
        content = f.read()
        content = re.sub('TP=1', f'TP={tp}', content)
    with open(file_path, 'w') as f:
        f.write(content)


if __name__ == '__main__':
    fire.Fire(main)