Unverified commit 122cdca5, authored by Reid, committed by GitHub
Browse files

[Misc] refactor context extension (#19246)


Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
parent cf02f9b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to extend the context length
of a Qwen model using the YARN method (rope_scaling)
and run a simple chat example.
Usage:
python examples/offline_inference/context_extension.py
"""
from vllm import LLM, SamplingParams
def create_llm() -> LLM:
    """Build the Qwen3-0.6B engine with a YaRN-extended context window.

    The Hugging Face config is overridden so that RoPE scaling of type
    ``"yarn"`` stretches the original 32768-position window by a factor of
    4.0, and ``max_model_len`` is set to the resulting product.

    Returns:
        The ``LLM`` instance constructed with those overrides.
    """
    rope_theta = 1000000
    original_max_position_embeddings = 32768
    factor = 4.0

    # Use yarn to extend context
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
        # Extended window = original window * scaling factor.
        "max_model_len": int(original_max_position_embeddings * factor),
    }
    llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
    return llm


def run_llm_chat(llm: LLM):
    """Run one fixed chat conversation through ``llm`` and return the result.

    Args:
        llm: The engine whose ``chat`` method is invoked.

    Returns:
        Whatever ``llm.chat`` returns for the hard-coded conversation.
    """
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=128,
    )

    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello! How can I assist you today?"},
    ]
    # Progress bar is disabled via use_tqdm=False.
    outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
    return outputs


def print_outputs(outputs): def print_outputs(outputs):
...@@ -44,4 +58,11 @@ def print_outputs(outputs): ...@@ -44,4 +58,11 @@ def print_outputs(outputs):
print("-" * 80) print("-" * 80)
def main() -> None:
    """Create the engine, run the sample chat, and print what it produced."""
    llm = create_llm()
    outputs = run_llm_chat(llm)
    print_outputs(outputs)


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment