trans_batch_demo.py 2.98 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""

This example shows how to make batched requests to GLM-4-0414 and glm-4-9b-chat-hf models with the transformers library.
You need to build the conversation prompts yourself and then call the batch function to submit them as a single batched request.
Please note that memory consumption in this demo is significantly higher.

"""

from typing import Union

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList


# Hugging Face Hub model id of the checkpoint used by this demo.
MODEL_PATH = "THUDM/GLM-4-9B-0414"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# device_map="auto" lets transformers decide where to place the weights.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
# Inference only: put the module in evaluation mode (eval() returns the same object).
model.eval()


def process_model_outputs(inputs, outputs, tokenizer):
    """Decode only the newly generated part of each sequence in a batch.

    Args:
        inputs: Tokenized batch whose ``input_ids`` rows are the prompts fed to generation.
        outputs: Generated id sequences, one per prompt, each beginning with its prompt ids.
        tokenizer: Tokenizer used to turn ids back into text.

    Returns:
        A list of decoded responses (special tokens skipped, surrounding
        whitespace stripped), paired with the prompts in order.
    """
    # Each output row repeats its prompt first, so skip that many ids before decoding.
    return [
        tokenizer.decode(generated[len(prompt_ids) :], skip_special_tokens=True).strip()
        for prompt_ids, generated in zip(inputs.input_ids, outputs)
    ]


def batch(
    model,
    tokenizer,
    messages: Union[str, list[str]],
    max_input_tokens: int = 8192,
    max_new_tokens: int = 8192,
    num_beams: int = 1,
    do_sample: bool = True,
    top_p: float = 0.8,
    temperature: float = 0.8,
    logits_processor=None,
):
    """Generate completions for one prompt or a batch of pre-formatted prompts.

    The prompts must already be rendered as text in the model's conversation
    format. This function tokenizes them, calls ``model.generate`` once for the
    whole batch, and decodes only the newly generated text.

    Args:
        model: Causal LM providing ``generate``, ``device`` and ``config.eos_token_id``.
        tokenizer: Tokenizer matching ``model``.
        messages: A single prompt string or a list of prompt strings.
        max_input_tokens: Prompts are truncated to at most this many tokens.
        max_new_tokens: Maximum number of tokens to generate per prompt.
        num_beams, do_sample, top_p, temperature: Forwarded to ``model.generate``.
        logits_processor: Optional logits processor list; an empty
            ``LogitsProcessorList`` is used when ``None``.

    Returns:
        A list with one decoded, whitespace-stripped response per prompt, in
        input order. An empty ``messages`` list yields an empty list.
    """
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    messages = [messages] if isinstance(messages, str) else messages
    if not messages:
        # Nothing to generate; don't hand an empty batch to the tokenizer/model.
        return []

    # Pad only to the longest prompt in this batch (padding=True). Padding every
    # row to max_input_tokens (8192 by default) wasted memory and compute on
    # pad tokens; max_input_tokens still caps prompt length via truncation.
    # NOTE(review): batched generation with a decoder-only model needs left
    # padding; this relies on the tokenizer's configured padding side — confirm.
    # NOTE(review): with right-side truncation an over-long prompt would lose
    # its end (including the generation prompt) — confirm truncation side.
    batched_inputs = tokenizer(
        messages, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": model.config.eos_token_id,
    }
    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
    batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
    return batched_response


if __name__ == "__main__":
    # Two independent conversations, each a list of role/content messages.
    batch_message = [
        [
            # "Why couldn't my dad and mom take me to their wedding?"
            {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
            # "Because you hadn't been born yet when they got married."
            {"role": "assistant", "content": "因为他们结婚时你还没有出生"},
            # "The question I just asked was..."
            {"role": "user", "content": "我刚才的提问是"},
        ],
        # "Hello, who are you?"
        [{"role": "user", "content": "你好,你是谁"}],
    ]

    batch_inputs = []
    # Input-token budget passed to `batch`; 128 acts as a lower bound.
    max_input_tokens = 128
    for messages in batch_message:
        # Render the conversation as text and drop its first 12 characters.
        # NOTE(review): presumably this strips the template's leading special
        # tokens (e.g. "[gMASK]<sop>") because the tokenizer call inside `batch`
        # adds them again — confirm against the tokenizer's chat template.
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
        # Measure the prompt in tokens, not characters: `max_input_tokens` is
        # used as the tokenizer's `max_length` in `batch`. Tokenizing with the
        # same default settings gives the exact length, so no prompt is truncated.
        max_input_tokens = max(max_input_tokens, len(tokenizer(prompt).input_ids))
        batch_inputs.append(prompt)
    gen_kwargs = {
        "max_input_tokens": max_input_tokens,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "num_beams": 1,
    }

    batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
    for response in batch_responses:
        print("=" * 10)
        print(response)