# Tensor (model) parallel + pipeline parallel QA demo for GLM-10B-Chinese.
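#
# A launch sketch, assuming OneFlow's torch-style distributed launcher and
# 4 devices (tensor_parallel_size * pipeline_parallel_size = 2 * 2):
#   python3 -m oneflow.distributed.launch --nproc_per_node 4 glm-QA.py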

import oneflow as flow
from projects.GLM.tokenizer.glm_tokenizer import GLMChineseTokenzier
from libai.utils import distributed as dist
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.utils.glm_loader import GLMLoaderHuggerFace
from omegaconf import DictConfig
import time

# The parallel scheme only needs this simple configuration.
parallel_config = DictConfig(
    dict(
        data_parallel_size=1,
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        pipeline_num_layers=2 * 24
    )
)
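# data_parallel_size * tensor_parallel_size * pipeline_parallel_size must equal
# the total device count (1 * 2 * 2 = 4 here); pipeline_num_layers (2 * 24 = 48)
# matches GLM-10B's 48 transformer layers, split across the two pipeline stages.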
dist.setup_dist_util(parallel_config)

tokenizer = GLMChineseTokenzier.from_pretrained("glm-10b-chinese")

sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)
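# Inputs are replicated (broadcast) along both parallel dimensions and placed
# on the devices hosting pipeline stage 0, where the embedding layer runs.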

# Build a loader that maps the HuggingFace "glm-10b-chinese" checkpoint onto
# LiBai's GLMForConditionalGeneration; dropout is disabled for deterministic
# inference.
loader = GLMLoaderHuggerFace(
    GLMForConditionalGeneration,
    cfg,
    "glm-10b-chinese",
    embedding_dropout_prob=0,
    attention_dropout_prob=0,
    output_dropout_prob=0,
)

if dist.is_main_process():
    print("请稍等,正在加载模型中...")
model = loader.load()

question = ""
while True:
    if dist.is_main_process():
        print('Input (type "退出" to quit):')
        question = input("> ")
    else:
        question = None
    # Broadcast the question from rank 0 so every rank takes the same branch.
    question = dist.broadcast_py_object(question, src=0)
    dist.synchronize()
    if question == "退出":  # "退出" means "quit"
        break

    # Append the Chinese prompt " 回答: " ("Answer:") and the [gMASK] token,
    # which marks the blank GLM fills in during generation.
    input_ids = tokenizer.encode(
        [question + " 回答: [gMASK]"],
        return_tensors="of",
    )
    inputs = {"input_ids": input_ids, "attention_mask": flow.ones(input_ids.size())}
    inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=128)
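    # inputs now holds input_ids, position_ids, and generation_attention_mask,
    # which model.generate consumes below.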
    
    if dist.is_main_process():
        print("Generating...")
    # start_t = time.time()
    # to_global turns the local tensors into global tensors with the given
    # sbp/placement so the parallelized model can consume them.
    outputs = model.generate(
        inputs=inputs["input_ids"].to_global(sbp=sbp, placement=placement),
        position_ids=inputs["position_ids"].to_global(sbp=sbp, placement=placement),
        generation_attention_mask=inputs["generation_attention_mask"].to_global(
            sbp=sbp, placement=placement
        ),
        max_length=128,
    )
    # end_t = time.time()
    # if dist.is_main_process():
    #     print("model.generate took %.2fs" % (end_t - start_t))

    res = tokenizer.decode(outputs[0])

    if dist.is_main_process():
        print("> " + res)

if dist.is_main_process():
    print("> 再见")