Commit 1a930b5d authored by zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]添加qwen的lora单测

See merge request dcutoolkit/deeplearing/vllm!57
parents b4cf96af 3f29b4a8
...@@ -212,6 +212,11 @@ def phi2_lora_files(): ...@@ -212,6 +212,11 @@ def phi2_lora_files():
# return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") # return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
return os.path.join(models_path_prefix, "isotr0py/phi-2-test-sql-lora") return os.path.join(models_path_prefix, "isotr0py/phi-2-test-sql-lora")
@pytest.fixture(scope="session")
def qwen_lora_files():
    """Return the path of the ``customize/qwen-nl2dsl-lora`` adapter directory.

    The path is built by joining the adapter's sub-path onto
    ``models_path_prefix``; this session-scoped fixture does not download
    anything itself.
    """
    adapter_subdir = "customize/qwen-nl2dsl-lora"
    return os.path.join(models_path_prefix, adapter_subdir)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def long_context_lora_files_16k_1(): def long_context_lora_files_16k_1():
......
from typing import List
import os
import vllm
from vllm.lora.request import LoRARequest
from ..utils import models_path_prefix
# Base model directory, resolved under the shared models prefix.
MODEL_PATH = os.path.join(models_path_prefix, "Qwen/Qwen1.5-32B-Chat")

# Chat prompt with a fixed system message; ``{query}`` is replaced with the
# user's question via ``str.format``.
PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n{query}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate completions for three fixed prompts and return the texts.

    Args:
        llm: Engine whose ``generate`` method is called with the prompts.
        lora_path: Location of the LoRA adapter passed to ``LoRARequest``.
        lora_id: Adapter id; a falsy value (e.g. ``0``) sends the request
            with ``lora_request=None``.

    Returns:
        The stripped text of the first output of each prompt, in prompt
        order.
    """
    queries = (
        "who are you?",
        "What is the capital city of China?",
        "What is the longest river in the world?",
    )
    prompts = [PROMPT_TEMPLATE.format(query=question) for question in queries]
    print(prompts)

    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)

    # Only attach an adapter when a non-zero id was requested.
    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = None

    results = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Collect the texts and echo each prompt/completion pair.
    generated_texts: List[str] = []
    for result in results:
        text = result.outputs[0].text.strip()
        generated_texts.append(text)
        print(f"Prompt: {result.prompt!r}, Generated text: {text!r}")
    return generated_texts
def test_qwen_lora(qwen_lora_files):
    """Check generations with the Qwen LoRA adapter against fixed texts.

    The same adapter files are requested under two different adapter ids
    (1, then 2); each run must reproduce every expected text exactly.
    """
    engine_kwargs = dict(
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_kwargs)

    expected_lora_output = [
        "I am a large language model created by Alibaba Cloud. I am called Qwen.",
        "The capital city of China is Beijing.",
        "The longest river in the world is the Nile, located in Africa. It stretches for approximately 4,135 miles (6,650 kilometers) from its source in the highlands of Rwanda, through Tanzania, Uganda, South Sudan, Sudan, and Egypt, before emptying into the Mediterranean Sea. The Nile is famous for its historical and cultural significance, particularly in relation to ancient Egyptian civilization.",
    ]

    # Run the identical check once per adapter id, in ascending order.
    for adapter_id in (1, 2):
        generated = do_sample(llm, qwen_lora_files, lora_id=adapter_id)
        for index, expected_text in enumerate(expected_lora_output):
            assert generated[index] == expected_text
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment