Unverified commit 0f4fb19b, authored by Ying Sheng, committed by GitHub

[Fix, LoRA] fix LoRA with updates in main (#1545)

parent 63ba2f8d
 # launch server
-# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
+# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
 # send requests
 # lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
@@ -22,12 +22,12 @@ json_data = {
     "sampling_params": {"max_new_tokens": 32},
     "lora_path": [
         "/home/ying/test_lora",
-        "/home/ying/test_lora_1",
-        "/home/ying/test_lora_2",
-        "lora3",
-        "lora4",
-        "/home/ying/test_lora",
-        "/home/ying/test_lora_1",
+        "lora1",
+        "lora2",
+        "lora1",
+        "lora2",
+        None,
+        None,
     ],
 }
 response = requests.post(
...
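The updated example launches the server with one unnamed adapter and two named ones (lora1, lora2), then references them by name in lora_path; a None entry falls back to the base model. A minimal end-to-end sketch of the new request pattern, assuming the server from the launch command above is running on sglang's default endpoint http://127.0.0.1:30000/generate (the prompt strings are illustrative):

import requests

# text[i] is generated with the adapter selected by lora_path[i],
# so the two lists must have the same length.
json_data = {
    "text": [
        "prompt served by the unnamed adapter",
        "prompt served by lora1",
        "prompt served by lora2",
        "second prompt served by lora1",
        "second prompt served by lora2",
        "prompt served by the base model",
        "second prompt served by the base model",
    ],
    "sampling_params": {"max_new_tokens": 32},
    "lora_path": [
        "/home/ying/test_lora",  # unnamed adapters are keyed by their path
        "lora1",                 # named via lora1=/home/ying/test_lora_1
        "lora2",
        "lora1",
        "lora2",
        None,                    # None disables LoRA for this prompt
        None,
    ],
}
response = requests.post("http://127.0.0.1:30000/generate", json=json_data)
print(response.json())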
@@ -28,18 +28,18 @@ from typing import Any, Dict, List, Optional, Tuple
 import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
...
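This hunk swaps the source of the tensor-parallel linear classes: the LoRA code now wraps sglang's own implementations in sglang.srt.layers.linear rather than vllm's. That matters because dispatch on layer type only works when the wrapped modules and the imported classes are literally the same; a hypothetical helper illustrating the dependency (the function name is made up for illustration):

from torch import nn

from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)


def is_lora_wrappable(module: nn.Module) -> bool:
    # isinstance only matches if the model was built from these same classes,
    # which is why this file must import from sglang.srt.layers.linear now
    # that main builds models with them instead of vllm's equivalents.
    return isinstance(
        module,
        (
            ColumnParallelLinear,
            MergedColumnParallelLinear,
            QKVParallelLinear,
            RowParallelLinear,
        ),
    )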
@@ -594,6 +594,16 @@ class ServerArgs:
                 "Please use sglang<=0.3.2 or wait for later updates."
             )
+        if isinstance(self.lora_paths, list):
+            lora_paths = self.lora_paths
+            self.lora_paths = {}
+            for lora_path in lora_paths:
+                if "=" in lora_path:
+                    name, path = lora_path.split("=", 1)
+                    self.lora_paths[name] = path
+                else:
+                    self.lora_paths[lora_path] = lora_path
+
 
 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
...
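The added ServerArgs block normalizes --lora-paths into a dict: a bare path registers the adapter under the path string itself (preserving the old behavior), while a name=path entry registers it under name; splitting on the first "=" only means the path part may itself contain "=". A standalone sketch of the same normalization (the function name is hypothetical):

from typing import Dict, List


def normalize_lora_paths(lora_paths: List[str]) -> Dict[str, str]:
    """Replicate the ServerArgs normalization: CLI entries -> {name: path}."""
    named = {}
    for entry in lora_paths:
        if "=" in entry:
            name, path = entry.split("=", 1)  # split once; paths may contain '='
            named[name] = path
        else:
            named[entry] = entry  # bare path: the path doubles as the name
    return named


# normalize_lora_paths(["/home/ying/test_lora", "lora1=/home/ying/test_lora_1"])
# -> {"/home/ying/test_lora": "/home/ying/test_lora",
#     "lora1": "/home/ying/test_lora_1"}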
@@ -97,9 +97,7 @@ class TestLoRA(unittest.TestCase):
         )
 
         with HFRunner(
-            base_path,
-            torch_dtype=torch_dtype,
-            is_generation=True,
+            base_path, torch_dtype=torch_dtype, model_type="generation"
         ) as hf_runner:
             hf_outputs = hf_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
@@ -108,7 +106,7 @@ class TestLoRA(unittest.TestCase):
         with HFRunner(
             base_path,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as hf_runner:
             hf_no_lora_outputs = hf_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -118,7 +116,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as srt_runner:
             srt_no_lora_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -198,7 +196,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             lora_paths=all_lora_paths,
             max_loras_per_batch=3,
             disable_cuda_graph=True,
@@ -211,7 +209,7 @@ class TestLoRA(unittest.TestCase):
         with HFRunner(
             base_path,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             output_str_only=True,
         ) as hf_runner:
             hf_outputs = hf_runner.forward(
@@ -237,7 +235,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as srt_runner:
             srt_no_lora_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens
@@ -247,7 +245,7 @@ class TestLoRA(unittest.TestCase):
             base_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
             lora_paths=all_lora_paths,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(
...
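The test changes are mechanical: every runner construction replaces the removed is_generation=True flag with model_type="generation", matching the updated runner signature in main. The paired comparison the test performs looks roughly like this sketch, assuming HFRunner and SRTRunner come from sglang.test.runners (model path, dtype, and prompt are illustrative):

import torch

from sglang.test.runners import HFRunner, SRTRunner

base_path = "mistralai/Mistral-7B-Instruct-v0.3"
prompts = ["Write a one-line greeting."]
max_new_tokens = 32

# Reference outputs from HuggingFace transformers.
with HFRunner(
    base_path, torch_dtype=torch.float16, model_type="generation"
) as hf_runner:
    hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

# Outputs from the sglang runtime, to be compared against the HF reference.
with SRTRunner(
    base_path,
    tp_size=1,
    torch_dtype=torch.float16,
    model_type="generation",
) as srt_runner:
    srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)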
@@ -7,7 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
-        # "models/test_lora.py",
+        "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
...