Commit 3f29b4a8 authored by 王敏's avatar 王敏
Browse files

[feat]添加qwen的lora单测

parent b4cf96af
......@@ -212,6 +212,11 @@ def phi2_lora_files():
# return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
return os.path.join(models_path_prefix, "isotr0py/phi-2-test-sql-lora")
@pytest.fixture(scope="session")
def qwen_lora_files():
# return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")
return os.path.join(models_path_prefix, "customize/qwen-nl2dsl-lora")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
......
from typing import List
import os
import vllm
from vllm.lora.request import LoRARequest
from ..utils import models_path_prefix
MODEL_PATH = os.path.join(models_path_prefix, "Qwen/Qwen1.5-32B-Chat")
PROMPT_TEMPLATE = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="who are you?"),
PROMPT_TEMPLATE.format(
query="What is the capital city of China?"
),
PROMPT_TEMPLATE.format(
query="What is the longest river in the world?"
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_qwen_lora(qwen_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
expected_lora_output = [
"I am a large language model created by Alibaba Cloud. I am called Qwen.",
"The capital city of China is Beijing.",
"The longest river in the world is the Nile, located in Africa. It stretches for approximately 4,135 miles (6,650 kilometers) from its source in the highlands of Rwanda, through Tanzania, Uganda, South Sudan, Sudan, and Egypt, before emptying into the Mediterranean Sea. The Nile is famous for its historical and cultural significance, particularly in relation to ancient Egyptian civilization.",
]
output1 = do_sample(llm, qwen_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, qwen_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment