Unverified Commit 9b945daa authored by Antoni Baum, committed by GitHub

[Experimental] Add multi-LoRA support (#1804)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
parent 9c1352eb
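
For orientation before the diff: the two central objects this change introduces are LoRARequest, which names an adapter, and LoRAConfig, which bounds adapter rank and adapter counts. A minimal construction sketch follows, mirroring the new test below; the adapter path is a placeholder, and nothing beyond these constructors is confirmed API from this commit.

# Sketch only: argument meanings mirrored from the test below; the adapter
# path is a placeholder, not an artifact shipped with this PR.
from vllm.lora.request import LoRARequest
from vllm.config import LoRAConfig

# An adapter request: human-readable name, unique positive integer id,
# local path to the adapter weights.
request = LoRARequest("sql_adapter", 1, "/path/to/lora_adapter")

# Engine-side limits: maximum adapter rank, adapters cached on CPU,
# adapters resident on GPU at once.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32)
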
import os
import random
import tempfile
from unittest.mock import patch

from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    worker = Worker(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            download_dir=None,
            load_format="dummy",
            seed=0,
            dtype="float16",
            revision=None,
        ),
        parallel_config=ParallelConfig(1, 1, False),
        scheduler_config=SchedulerConfig(32, 32, 32, 256),
        local_rank=0,
        rank=0,
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_model()
    worker.load_model()

    # With no active adapters, the worker should report an empty LoRA set.
    worker.model_runner.set_active_loras([], LoRAMapping([], []))
    assert worker.list_loras() == set()

    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    # Activating all 32 adapters should make every lora_int_id visible.
    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    # Repeatedly activate random subsets (with duplicates) and check that
    # every requested adapter ends up loaded. Note the slice below may
    # yield an empty list when randint returns 0, since [:-0] == [:0].
    for i in range(32):
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
        worker.model_runner.set_active_loras(iter_lora_requests,
                                             LoRAMapping([], []))
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})
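
A note on the LoRAMapping([], []) argument used throughout: set_active_loras pairs the list of requests with a mapping built from two sequences. The empty mappings above mean no token is routed through any adapter; the reading below of the two fields is inferred from how the test uses them, not something this file states.

# Assumed semantics (inferred, not confirmed by this test): the first
# sequence maps each token in the flattened batch to an adapter id, the
# second does the same per request, with 0 meaning "no adapter".
from vllm.lora.models import LoRAMapping

# Hypothetical batch: four tokens split across two prompts, each prompt
# served by a different adapter.
mapping = LoRAMapping([1, 1, 2, 2], [1, 2])
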
@@ -19,10 +19,11 @@ class MockLogitsSampler(Sampler):
         self.fake_logits = fake_logits

     def forward(self, *args, **kwargs):
-        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
-                   lambda x, y: x), patch(
-                       "vllm.model_executor.layers.sampler._get_logits",
-                       lambda *args, **kwargs: self.fake_logits):
+        with patch(
+                "vllm.model_executor.layers.sampler._prune_hidden_states",
+                lambda x, y: x), patch(
+                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
+                    lambda *args, **kwargs: self.fake_logits):
             return super().forward(*args, **kwargs)
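
The patch target moves from a module-level _get_logits to Sampler._get_logits, indicating the helper became a method on the Sampler class in this PR. Patching the attribute on the class swaps it for every instance, including subclasses such as MockLogitsSampler, which is why the mock still intercepts the call. A self-contained illustration of that mechanic, with toy names:

# Toy example: patching a class attribute affects all instances (and
# subclasses) for the duration of the context manager.
from unittest.mock import patch

class Toy:
    def _get_logits(self):
        return "real"

with patch.object(Toy, "_get_logits", lambda self: "fake"):
    assert Toy()._get_logits() == "fake"
assert Toy()._get_logits() == "real"
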
@@ -38,7 +39,7 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner
@@ -266,7 +267,7 @@ def test_sampler_top_k_top_p(seed: int):
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     generation_model = GenerationMixin()
     generation_config = GenerationConfig(top_k=top_k,
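
Both hunks above gain a fourth positional None in the ModelRunner constructor (as does test_prepare_prompt below). The diff never shows the new signature; given the PR's scope, the added parameter is presumably the LoRA config, so treat the shape below as an assumption.

# Assumed constructor shape after this PR (signature not shown in the diff):
#     ModelRunner(model_config, parallel_config, scheduler_config, lora_config)
model_runner = ModelRunner(None, None, None, None)  # the tests opt out of every config
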
...@@ -83,8 +83,8 @@ def create_worker(cls: type, ...@@ -83,8 +83,8 @@ def create_worker(cls: type,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
) )
(model_config, cache_config, parallel_config, (model_config, cache_config, parallel_config, scheduler_config,
scheduler_config) = engine_args.create_engine_configs() _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
......
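
create_engine_configs now returns a five-tuple, and the helper discards the fifth element. By the same inference as above it is presumably the new LoRA config; the name below is assumed, since the hunk only binds it to _.

# Assumed return shape; lora_config is an inferred name for the discarded
# fifth element, not one shown in this diff.
(model_config, cache_config, parallel_config, scheduler_config,
 lora_config) = engine_args.create_engine_configs()
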
...@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner ...@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
def test_prepare_prompt(): def test_prepare_prompt():
model_runner = ModelRunner(None, None, None) model_runner = ModelRunner(None, None, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
...@@ -33,7 +33,7 @@ def test_prepare_prompt(): ...@@ -33,7 +33,7 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += max_seq_len selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _ = ( input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list)) model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
......
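
_prepare_prompt grows from five return values to eight, of which the test still asserts only on the prompt lengths. The extra trailing values are plausibly per-token and per-request LoRA mappings plus the set of LoRA requests for the batch, but those names are inferred from this PR's subject rather than shown in the hunk.

# Hypothetical unpacking with descriptive names; everything after
# prompt_lens is an assumed interpretation of the new return values.
(input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens,
 lora_index_mapping, lora_prompt_mapping, lora_requests) = (
    model_runner._prepare_prompt(seq_group_metadata_list))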