# infer.py — single-prompt text-generation demo for OpenMoE checkpoints.
from argparse import ArgumentParser

import torch
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig


def parse_args():
    """Parse command-line options for the inference script.

    Returns:
        argparse.Namespace with a single ``model`` attribute, one of
        "base", "8b", or "test" (default "base").
    """
    cli = ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="base",
        choices=["base", "8b", "test"],
        help="model path",
    )
    return cli.parse_args()


def inference(args):
    """Generate a completion for a fixed demo prompt with an OpenMoE model.

    Args:
        args: parsed CLI namespace; ``args.model`` selects the checkpoint
            ("base" or "8b" download pretrained weights from the HuggingFace
            hub; "test" builds a randomly-initialized model from the base
            config without downloading weights).

    Side effects:
        Downloads tokenizer/model artifacts from the HuggingFace hub and
        prints the decoded generation to stdout. Requires a CUDA device.
    """
    # OpenMoE reuses the umT5 tokenizer vocabulary.
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    if args.model == "test":
        # Random-weight smoke-test model; exercises the fused kernel path.
        config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
        set_openmoe_args(config,
                         num_experts=config.num_experts,
                         moe_layer_interval=config.moe_layer_interval,
                         enable_kernel=True)
        model = OpenMoeForCausalLM(config)
    else:
        config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
        set_openmoe_args(config,
                         num_experts=config.num_experts,
                         moe_layer_interval=config.moe_layer_interval,
                         enable_kernel=False)
        model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
    # bf16 inference on the current CUDA device.
    model = model.eval().bfloat16()
    model = model.to(torch.cuda.current_device())

    input_str = """```
y = list(map(int, ['1', 'hello', '2']))
```
What error does this program produce?
ValueError: invalid literal for int() with base 10: 'hello'

```
sum = 0
for i in range(100):
        sum += i
```
What is the value of sum immediately after the 10th time line 3 is executed?"""

    # print("model config: ", model.config)
    # "<pad>" is prepended manually, so skip the tokenizer's own special tokens.
    input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
    input_ids = input_ids.input_ids.to(torch.cuda.current_device())
    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
    # Keep special tokens visible so pad/eos placement can be inspected.
    out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
    print(f"output: \n{out}\n")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run a single generation pass.
    cli_args = parse_args()
    inference(cli_args)