# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Demonstrate requesting per-token logprobs from ``LLM.chat`` and saving the
logprob of each generated (rank-1) token for the first 10 steps to a JSON file."""

import json

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    EngineArgs.add_cli_args(parser)
    parser.set_defaults(model="Qwen/Qwen3-30B-A3B")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument(
        "--max-tokens",
        type=int,
        default=8192,
        help="Maximum number of tokens to generate in a single response.",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling. Higher values make output more random.",
    )
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="Top-p sampling: only the smallest set of tokens whose cumulative "
        "probability exceeds top_p is considered.",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=1,
        help="Top-k sampling. -1 means no top-k.",
    )
    # Add example params
    parser.add_argument(
        "--chat-template-path",
        type=str,
        help="Path to a custom chat template file (Jinja format).",
    )
    return parser


def main(args: dict):
    # Pop arguments not used by LLM
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")
    chat_template_path = args.pop("chat_template_path")

    # Create an LLM
    llm = LLM(**args)

    # Create a sampling params object. logprobs=10 requests the top-10
    # log-probabilities for every generated token.
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        logprobs=10,
    )

    # A chat template can optionally be supplied.
    # If not, the model will use its default chat template.
    chat_template = None
    if chat_template_path is not None:
        with open(chat_template_path) as f:
            chat_template = f.read()
        print(f"Loaded custom chat template from: {chat_template_path}")

    # Define a single conversation for demonstration.
    single_conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "介绍一下北京."},  # "Introduce Beijing."
    ]

    outputs = llm.chat(
        single_conversation,
        sampling_params,
        use_tqdm=False,
        chat_template=chat_template,
    )

    print(f"Original input prompt:\n{single_conversation[1]['content']!r}\n")

    first_10_logprobs_to_save = []
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text (full output):\n{generated_text!r}")
        print("=" * 80)

        logprobs_per_step = output.outputs[0].logprobs
        if logprobs_per_step is None:
            print("Logprobs not returned. Check your SamplingParams.")
            continue

        print("\nLogprobs per generated token:")
        # Inspect only the first 10 generation steps.
        for step_idx, step_logprobs_dict in enumerate(logprobs_per_step[:10]):
            # The rank-1 entry is the token that was actually generated.
            generated_token_info = None
            for token_id, logprob_obj in step_logprobs_dict.items():
                if logprob_obj.rank == 1:
                    generated_token_info = (token_id, logprob_obj.decoded_token)
                    break

            if generated_token_info:
                token_id, token_text = generated_token_info
                print(f"  Step {step_idx}:")
                print(f"    - Generated Token: {token_id} ('{token_text}')")
            else:
                print(f"  Step {step_idx}: (Could not find rank-1 token)")
                continue

            # Print all returned logprobs for this step, ordered by rank.
            sorted_logprobs = sorted(step_logprobs_dict.values(), key=lambda x: x.rank)
            print("    - Top Logprobs:")
            for logprob_obj in sorted_logprobs:
                token_id = next(
                    tid for tid, lp in step_logprobs_dict.items() if lp is logprob_obj
                )
                token_text = logprob_obj.decoded_token
                logprob_value = logprob_obj.logprob
                rank = logprob_obj.rank
                print(
                    f"      - Rank {rank}: Token {token_id} ('{token_text}') "
                    f"-> Logprob: {logprob_value:.4f}"
                )
                if rank == 1:
                    first_10_logprobs_to_save.append(logprob_value)

    output_filename = "./Qwen3-30B-A3B_logprobs_K100AI_fp16.json"
    with open(output_filename, "w") as f:
        json.dump(first_10_logprobs_to_save, f, indent=2)
    print(f"Successfully wrote the logprob of each generated token to: {output_filename}")


if __name__ == "__main__":
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)