infer_vllm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser
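
# Example invocation (illustrative; adjust the model and engine args to your environment):
#   python infer_vllm.py --model Qwen/Qwen3-30B-A3B --tensor-parallel-size 4 --max-tokens 1024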


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    EngineArgs.add_cli_args(parser)
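    # This adds the standard vLLM engine flags (e.g. --tensor-parallel-size, --dtype,
    # --max-model-len) to the parser; they are forwarded to LLM(**args) in main().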

    # Default to Qwen3-30B-A3B; override with --model on the command line.
    parser.set_defaults(model="Qwen/Qwen3-30B-A3B")

    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int, default=8192,
                                help="Maximum number of tokens to generate in a single response.")
    sampling_group.add_argument("--temperature", type=float, default=0.0,
                                help="Temperature for sampling. Higher values make output more random.")
    sampling_group.add_argument("--top-p", type=float, default=1.0,
                                help="Top-p sampling probability. Only tokens with cumulative probability below top_p are considered.")
    sampling_group.add_argument("--top-k", type=int, default=1,
                                help="Top-k sampling. -1 means no top-k.")
    
    # Add example params
    parser.add_argument("--chat-template-path", type=str,
                        help="Path to a custom chat template file (Jinja format).")

    return parser


def main(args: dict):
    # Pop arguments not used by LLM
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")
    chat_template_path = args.pop("chat_template_path")

    # Create an LLM
    llm = LLM(**args)

    # Create the sampling params object.
    # logprobs=10 asks vLLM to return the top-10 log-probabilities for every generated token.
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        logprobs=10,
    )
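    # Note: with the default temperature of 0.0, decoding is effectively greedy, so the
    # rank-1 entry in the returned logprobs at each step is the token that was generated.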

    # A chat template can be optionally supplied.
    # If not, the model will use its default chat template.
    chat_template = None
    if chat_template_path is not None:
        with open(chat_template_path) as f:
            chat_template = f.read()
        print(f"Loaded custom chat template from: {chat_template_path}")


    # Define a single demo conversation; the user prompt asks (in Chinese) for an introduction to Beijing.
    single_conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "介绍一下北京."},
    ]
    
    outputs = llm.chat(
        single_conversation,
        sampling_params,
        use_tqdm=False,
        chat_template=chat_template,
    )
    print(f"Original input prompt:\n{single_conversation[1]['content']!r}\n")

    # Collect the logprob of the generated (rank-1) token for the first 10 decoding steps.
    first_10_logprobs_to_save = []

    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text (full output):\n{generated_text!r}")
        print("=" * 80)
        logprobs_per_step = output.outputs[0].logprobs

        if logprobs_per_step is None:
            print("Logprobs not returned. Check your SamplingParams.")
            continue

        print("\nLogprobs per generated token:")
        for step_idx, step_logprobs_dict in enumerate(logprobs_per_step[:10]):
            
            generated_token_info = None
            for token_id, logprob_obj in step_logprobs_dict.items():
                if logprob_obj.rank == 1:
                    generated_token_info = (token_id, logprob_obj.decoded_token)
                    break 
            
            if generated_token_info:
                token_id, token_text = generated_token_info
                print(f"  Step {step_idx}:")
                print(f"    - Generated Token: {token_id} ('{token_text}')")
            else:
                print(f"  Step {step_idx}: (Could not find rank-1 token)")
                continue

            # Report every returned candidate for this step, ordered by rank.
            print("    - Top Logprobs:")
            for token_id, logprob_obj in sorted(step_logprobs_dict.items(),
                                                key=lambda item: item[1].rank):
                token_text = logprob_obj.decoded_token
                logprob_value = logprob_obj.logprob
                rank = logprob_obj.rank

                print(f"        - Rank {rank}: Token {token_id} ('{token_text}') -> Logprob: {logprob_value:.4f}")
                if rank == 1:
                    first_10_logprobs_to_save.append(logprob_value)


    # Write the collected rank-1 logprobs to a JSON file.
    output_filename = './Qwen3-30B-A3B_logprobs_K100AI_fp16.json'
    with open(output_filename, 'w') as f:
        json.dump(first_10_logprobs_to_save, f, indent=2)

    print(f"Successfully wrote the logprob of each generated token to file: {output_filename}")


if __name__ == "__main__":
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)