decode.py
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp

import fire
import torch

from lmdeploy import turbomind as tm
from lmdeploy.tokenizer import Tokenizer

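# suppress TurboMind's informational logging; only errors are printed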
os.environ['TM_LOG_LEVEL'] = 'ERROR'


def main(model_path, inputs):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        inputs (str): the path of text file contatin input text lines
    """
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
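    # create the TurboMind engine and tell it which token id ends a sequence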
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
    generator = tm_model.create_instance()

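    # read one prompt per line from the input file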
    with open(inputs, 'r') as f:
        lines = f.readlines()

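    # tokenize each line (readlines keeps the trailing newline in each prompt)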
    input_ids = [tokenizer.encode(x) for x in lines]

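    # run a forward pass over the prompts and fetch per-token vocabulary logits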
    logits = generator.decode(input_ids)

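    # greedy pick: the index of the highest logit at each position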
    top_1 = torch.argmax(logits, -1)

    print(top_1)


if __name__ == '__main__':
    fire.Fire(main)
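
# Example invocation (a sketch; the workspace path and prompt file below are
# placeholders for a converted model directory and a local text file; fire
# maps the keyword flags onto the arguments of `main`):
#   python decode.py --model_path ./workspace --inputs ./prompts.txt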