Unverified commit 0f4fb19b authored by Ying Sheng, committed by GitHub

[Fix, LoRA] fix LoRA with updates in main (#1545)

parent 63ba2f8d
# launch server
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
# send requests
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
@@ -22,12 +22,12 @@ json_data = {
"sampling_params": {"max_new_tokens": 32},
"lora_path": [
"/home/ying/test_lora",
"/home/ying/test_lora_1",
"/home/ying/test_lora_2",
"lora3",
"lora4",
"/home/ying/test_lora",
"/home/ying/test_lora_1",
"lora1",
"lora2",
"lora1",
"lora2",
None,
None,
],
}
response = requests.post(
......
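For context, a minimal end-to-end client for the request above might look like the sketch below. It assumes the server was started with the second launch command and listens on sglang's default port 30000; the prompts and adapter names are placeholders, and each lora_path entry pairs with the text entry at the same index (None falls back to the base model).

# Hypothetical client sketch (assumes default port 30000 and adapters named lora1/lora2).
import json

import requests

url = "http://127.0.0.1:30000/generate"
json_data = {
    "text": [
        "Give me a short introduction to LoRA.",    # served with adapter lora1
        "Give me a short introduction to LoRA.",    # served with adapter lora2
        "Give me a short introduction to LoRA.",    # served with the base model
    ],
    "sampling_params": {"max_new_tokens": 32},
    # lora_path[i] is applied to text[i]; the two lists must have the same length.
    "lora_path": ["lora1", "lora2", None],
}
response = requests.post(url, json=json_data)
print(json.dumps(response.json(), indent=2))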
@@ -28,18 +28,18 @@ from typing import Any, Dict, List, Optional, Tuple
import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
......
@@ -594,6 +594,16 @@ class ServerArgs:
"Please use sglang<=0.3.2 or wait for later updates."
)
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
......
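The new ServerArgs logic above normalizes --lora-paths into a dict, accepting both bare paths and name=path entries. A standalone sketch of the same behavior (a hypothetical helper, not part of ServerArgs) is:

from typing import Dict, List


def normalize_lora_paths(lora_paths: List[str]) -> Dict[str, str]:
    """Mirror of the --lora-paths normalization shown above (hypothetical helper)."""
    normalized: Dict[str, str] = {}
    for lora_path in lora_paths:
        if "=" in lora_path:
            # "name=path" entries are keyed by the given adapter name.
            name, path = lora_path.split("=", 1)
            normalized[name] = path
        else:
            # Bare paths are keyed by the path itself.
            normalized[lora_path] = lora_path
    return normalized


# Example, mirroring the second launch command above:
# normalize_lora_paths(["/home/ying/test_lora", "lora1=/home/ying/test_lora_1"])
# -> {"/home/ying/test_lora": "/home/ying/test_lora", "lora1": "/home/ying/test_lora_1"}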
@@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase):
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
@@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase):
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
@@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase):
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
output_str_only=True,
) as hf_runner:
hf_outputs = hf_runner.forward(
@@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
@@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase):
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
model_type="generation",
lora_paths=all_lora_paths,
) as srt_runner:
srt_outputs = srt_runner.forward(
......
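The test changes above replace the old boolean is_generation=True argument with model_type="generation" on both HFRunner and SRTRunner. A minimal sketch of the updated call pattern (runner class from the sglang test utilities; model, prompt, and dtype are placeholders) is:

import torch

from sglang.test.runners import HFRunner  # test utility exercised in the diff above

base_path = "mistralai/Mistral-7B-Instruct-v0.3"      # placeholder base model
prompts = ["Give me a short introduction to LoRA."]    # placeholder prompt

# The boolean is_generation=True flag is replaced by the model_type string.
with HFRunner(base_path, torch_dtype=torch.float16, model_type="generation") as hf_runner:
    hf_outputs = hf_runner.forward(prompts, max_new_tokens=32)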
@@ -7,7 +7,7 @@ suites = {
"minimal": [
"models/test_embedding_models.py",
"models/test_generation_models.py",
# "models/test_lora.py",
"models/test_lora.py",
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",
......