"docs/references/vscode:/vscode.git/clone" did not exist on "7f028b07c4cdd28b6b874cab2f0dd747e87f385f"
offline_batch_inference_qwen_1m.py 1.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Usage:
python3 offline_batch_inference.py
"""

from urllib.request import urlopen

import sglang as sgl


def load_prompt() -> str:
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt

    with urlopen(
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
        "/Qwen2.5-1M/test-data/64k.txt",
        timeout=5,
    ) as response:
        prompt = response.read().decode("utf-8")
    return prompt


# Generate completions for a batch of prompts and print the results.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    # Create a sampling params object.
    sampling_params = {
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "repetition_penalty": 1.05,
        "max_new_tokens": 256,
    }
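    # Note: these sampling values are assumed to follow the generation settings
    # recommended for the Qwen2.5 instruct models; adjust them for your workload.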
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt_tokens = output["meta_info"]["prompt_tokens"]  # number of prompt tokens
        generated_text = output["text"]
        print(f"Prompt length: {prompt_tokens}, Generated text: {generated_text!r}")


# Create an LLM engine configured for a 1M-token context window.
def initialize_engine() -> sgl.Engine:
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=1048576,  # 1M-token context window
        page_size=256,  # KV-cache page size in tokens
        attention_backend="dual_chunk_flash_attn",  # dual chunk attention used by Qwen2.5-1M for long contexts
        tp_size=4,  # tensor parallelism across 4 GPUs
        disable_radix_cache=True,  # prefix (radix) caching is disabled for these one-off long prompts
        enable_mixed_chunk=False,
        enable_torch_compile=False,
        chunked_prefill_size=131072,  # prefill the long prompt in 128K-token chunks
        mem_fraction_static=0.6,  # fraction of GPU memory reserved for weights and the KV cache pool
        log_level="DEBUG",
    )
    return llm


def main():
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == "__main__":
    main()