Unverified commit 4d6ada94, authored by Swapnil Parekh, committed by GitHub
Browse files

[CORE] Adding support for insertion of soft-tuned prompts (#4645)


Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent a0550cbc
......@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy tests --config-file pyproject.toml
......
......@@ -92,11 +92,10 @@ def batched_generate(
for input in inputs:
prompt, sampling_param, lora_req = input
# Add requests to the engine and run the engine
llm._validate_and_add_requests(
prompt,
sampling_param,
lora_request=lora_req,
)
llm._validate_and_add_requests(prompt,
sampling_param,
lora_request=lora_req,
prompt_adapter_request=None)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
......
......@@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model):
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
with pytest.raises(ValueError):
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] is None
assert manager.add_lora(model_lora2)
assert manager.activate_lora(3)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] is None
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
......@@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.add_lora(model_lora2)
assert manager.deactivate_lora(3)
assert manager.add_adapter(model_lora2)
assert manager.deactivate_adapter(3)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.pin_lora(2)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.activate_lora(1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.deactivate_lora(2)
assert manager.deactivate_adapter(2)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.pin_lora(3)
assert manager.pin_lora(1)
assert manager.pin_adapter(3)
assert manager.pin_adapter(1)
with pytest.raises(RuntimeError):
assert manager.pin_lora(2)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
with pytest.raises(RuntimeError):
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.deactivate_lora(3)
assert manager.pin_lora(2)
assert manager.deactivate_adapter(3)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.remove_lora(3)
assert manager.remove_adapter(3)
with pytest.raises(ValueError):
assert manager.pin_lora(3)
assert manager.pin_adapter(3)
def test_lru_lora_model_manager(dist_init, dummy_model):
......@@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert all(x is None for x in manager.lora_index_to_id)
# Add up to capacity
assert manager.add_lora(model_lora1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(1)
assert manager.activate_lora(2)
assert manager.add_adapter(model_lora1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(1)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {1, 2}
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
# Add over capacity
assert manager.add_lora(model_lora3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(3)
assert manager.activate_lora(4)
assert manager.add_adapter(model_lora3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(3)
assert manager.activate_adapter(4)
assert set(manager.list_loras()) == {3, 4}
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
# Add 3 again to move it to the top and then add 2
# should return false since it's in already
assert not manager.add_lora(model_lora3)
assert not manager.activate_lora(3)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora3)
assert not manager.activate_adapter(3)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {3, 2}
assert set(manager.list_adapters()) == {3, 2}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_loras()) == {2}
assert set(manager.list_adapters()) == {2}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 2
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_loras()) == {3, 4}
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {4}
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
assert not manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert not manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
# pinning
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_adapters()) == {3, 4}
with pytest.raises(ValueError):
assert manager.pin_lora(1)
assert manager.pin_lora(3)
assert manager.pin_adapter(1)
assert manager.pin_adapter(3)
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_loras()) == {4}
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.add_lora(model_lora1)
assert manager.pin_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert manager.add_adapter(model_lora1)
assert manager.pin_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {1, 2}
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {1}
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] is None
with pytest.raises(RuntimeError):
assert manager.remove_oldest_lora()
assert manager.remove_oldest_adapter()
assert set(manager.list_loras()) == {1}
assert set(manager.list_adapters()) == {1}
def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = LRUCacheWorkerLoRAManager(
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
......@@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
], mapping)
def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = WorkerLoRAManager(
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
......@@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA)
assert manager.add_lora(model_lora)
assert manager.add_lora(model_lora1)
assert manager.add_adapter(model_lora)
assert manager.add_adapter(model_lora1)
packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
......
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
    """Generate completions for two fixed tweet prompts.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params,
            prompt_adapter_request=...)``.
        pa_name: Name passed to the ``PromptAdapterRequest``.
        pa_id: Adapter id; a falsy value (e.g. 0) disables the adapter and
            ``None`` is passed to ``generate`` instead of a request.

    Returns:
        The stripped text of the first completion of every output, in the
        order the outputs were returned.
    """
    # NOTE(review): the backslash continuation lives inside the string
    # literal, so whatever precedes "current" on the next line is part of
    # the prompt. Confirm against the upstream file.
    prompts = [
        "Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
        "Tweet text : @nationalgridus Looks good thanks! Label : "
    ]
    sampling_params = vllm.SamplingParams(temperature=0.0,
                                          max_tokens=3,
                                          stop_token_ids=[3])

    # Only build an adapter request when an adapter id was supplied.
    adapter_request = None
    if pa_id:
        adapter_request = PromptAdapterRequest(pa_name, pa_id, PA_PATH, 8)

    outputs = llm.generate(prompts,
                           sampling_params,
                           prompt_adapter_request=adapter_request)

    generated_texts = [item.outputs[0].text.strip() for item in outputs]
    # Print the outputs.
    for item, generated_text in zip(outputs, generated_texts):
        prompt = item.prompt
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
    """Check the classification labels produced with prompt adapter id 1.

    Runs once with ``enforce_eager=True`` and once with ``False``.
    """
    engine_options = dict(enforce_eager=enforce_eager,
                          enable_prompt_adapter=True,
                          max_prompt_adapter_token=8)
    llm = vllm.LLM(MODEL_PATH, **engine_options)

    generated = do_sample(llm, "twitter_pa", pa_id=1)
    assert generated == ['complaint', 'no complaint']
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
    """Feed four requests to ``engine`` one per step and collect the results.

    Three requests carry a ``PromptAdapterRequest`` (ids 1, 2 and 3) and one
    carries no adapter. A new request is added before each ``engine.step()``
    while any remain; stepping continues until the engine reports no
    unfinished requests.

    Returns:
        The set of first-completion texts of all finished requests.
    """
    pending = [
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3), None),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("complain", 3, pa_path, 8)),
    ]

    collected = set()
    next_request_id = 0
    while pending or engine.has_unfinished_requests():
        if pending:
            text, params, adapter_request = pending.pop(0)
            engine.add_request(str(next_request_id),
                               text,
                               params,
                               prompt_adapter_request=adapter_request)
            next_request_id += 1

        # Keep only the text of requests that finished during this step.
        collected.update(step_output.outputs[0].text
                         for step_output in engine.step()
                         if step_output.finished)
    return collected
def test_multi_prompt_adapters():
    """Run several prompt adapters through one engine and check the outputs."""
    engine = LLMEngine.from_engine_args(
        EngineArgs(model=MODEL_PATH,
                   max_prompt_adapters=3,
                   enable_prompt_adapter=True,
                   max_prompt_adapter_token=8))

    observed = do_sample(engine)
    assert observed == {
        ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
    }
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
    """Run one SQL prompt twice with a LoRA: with and without a prompt adapter.

    A new request is added before each ``engine.step()`` while any remain;
    stepping continues until the engine reports no unfinished requests.

    Returns:
        The set of first-completion texts of all finished requests.
    """
    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # first prompt with a prompt adapter and second without adapter
    pending = [
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         PromptAdapterRequest("hate_speech", 1, pa_path, 8),
         LoRARequest("sql_test", 1, lora_path)),
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         None,
         LoRARequest("sql_test", 1, lora_path)),
    ]

    collected = set()
    next_request_id = 0
    while pending or engine.has_unfinished_requests():
        if pending:
            text, params, adapter_request, lora_request = pending.pop(0)
            engine.add_request(str(next_request_id),
                               text,
                               params,
                               prompt_adapter_request=adapter_request,
                               lora_request=lora_request)
            next_request_id += 1

        # Keep only the text of requests that finished during this step.
        collected.update(step_output.outputs[0].text
                         for step_output in engine.step()
                         if step_output.finished)
    return collected
def test_lora_prompt_adapter():
    """Check that LoRA and prompt adapter requests can share one engine."""
    engine = LLMEngine.from_engine_args(
        EngineArgs(model=MODEL_PATH,
                   enable_prompt_adapter=True,
                   enable_lora=True,
                   max_num_seqs=60,
                   max_prompt_adapter_token=8))

    observed = do_sample(engine)
    # Results are collected in a set, so a single element means both requests
    # produced the same text.
    assert observed == {
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    }
......@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
......@@ -92,6 +93,7 @@ class AsyncLLM:
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]:
if prompts is None:
......
......@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner
......
from dataclasses import dataclass
from typing import Tuple
@dataclass
class AdapterMapping:
    """Pair of integer mappings, both normalized to tuples on construction."""

    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        # Callers may pass any iterable (e.g. a list); store tuples instead.
        for field_name in ("index_mapping", "prompt_mapping"):
            setattr(self, field_name, tuple(getattr(self, field_name)))
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from torch import nn
from vllm.logger import init_logger
from vllm.utils import LRUCache
logger = init_logger(__name__)
class AdapterModel(ABC):
    """Abstract base for an adapter identified by ``id``."""

    def __init__(self, model_id=None):
        self.id = model_id

    # The method receives ``cls``, so it must be a classmethod. Without the
    # decorator, ``cls`` is bound to an instance (or to the first positional
    # argument when called on the class). ``classmethod`` goes outermost so
    # the abstract marker is still seen by ABC.
    @classmethod
    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        """Alternate constructor: build an adapter from a local checkpoint.

        Args:
            model_dir: Location of the checkpoint.
            model_id: Id to assign to the created adapter.
            **kwargs: Subclass-specific options.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Common initialization code
        # Load weights or embeddings from local checkpoint
        raise NotImplementedError("Subclasses must implement this method.")
T = TypeVar('T')
class AdapterLRUCache(LRUCache[T]):
    """``LRUCache`` that calls a deactivation callback on entry removal."""

    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
                                                              None]):
        """Create the cache.

        Args:
            capacity: Forwarded to ``LRUCache.__init__``.
            deactivate_fn: Called with the key of each removed entry.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: Hashable, value: T):
        """Log the removal, deactivate ``key``, then defer to the base hook.

        NOTE(review): presumably invoked by ``LRUCache`` whenever an entry
        is evicted or popped; confirm against ``vllm.utils.LRUCache``.
        """
        logger.debug("Removing adapter int id: %d", key)
        on_deactivate = self.deactivate_fn
        on_deactivate(key)
        return super()._on_remove(key, value)
class AdapterModelManager(ABC):
    """Abstract interface for managing adapters attached to a model."""

    def __init__(
        self,
        model: nn.Module,
    ):
        """Create an AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        self.adapter_type = 'Adapter'
        self._last_mapping = None
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}

    def __len__(self) -> int:
        """Return the number of registered adapters."""
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self):
        """Abstract property; subclasses supply the value."""

    @property
    @abstractmethod
    def capacity(self):
        """Abstract property; subclasses supply the value."""

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        """Activate the adapter with the given id (abstract)."""

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate the adapter with the given id (abstract)."""

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        """Add an adapter to the manager (abstract)."""

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        """Set the adapter mapping (abstract)."""

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        """Remove the adapter with the given id (abstract)."""

    @abstractmethod
    def remove_all_adapters(self):
        """Remove every adapter (abstract)."""

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Look up an adapter by id (abstract)."""

    @abstractmethod
    def list_adapters(self) -> Dict[int, Any]:
        """Return the adapters keyed by id (abstract)."""

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        """Pin the adapter with the given id (abstract)."""
from abc import abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
    """
    Base class for adapter requests.

    Equality and hashing are based solely on ``adapter_id``.
    """

    @property
    @abstractmethod
    def adapter_id(self):
        """Id of the requested adapter; subclasses provide the value."""

    def __post_init__(self):
        # Ids start at 1; anything lower is rejected.
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        # Only instances of this object's own class (or its subclasses)
        # can compare equal.
        if not isinstance(value, self.__class__):
            return False
        return self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)
from typing import Any, Callable, Dict, Optional, Set
## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
                       deactivate_func: Callable) -> bool:
    """Deactivate ``adapter_id`` if it is currently active.

    Args:
        adapter_id: Id of the adapter to deactivate.
        active_adapters: Active adapter ids; the id is popped on success.
        deactivate_func: Called with ``adapter_id`` before it is popped.

    Returns:
        True if the adapter was active, False otherwise (nothing is called).
    """
    if adapter_id not in active_adapters:
        return False

    deactivate_func(adapter_id)
    active_adapters.pop(adapter_id)
    return True
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
                capacity: int, add_func: Callable) -> bool:
    """Register ``adapter`` unless its id is already registered.

    Args:
        adapter: Object with an ``id`` attribute.
        registered_adapters: Registry keyed by adapter id; updated in place.
        capacity: Maximum number of registered adapters.
        add_func: Called with ``adapter`` before it is stored.

    Returns:
        True if the adapter was added, False if its id was already present.

    Raises:
        RuntimeError: If the registry already holds ``capacity`` adapters.
    """
    if adapter.id in registered_adapters:
        return False
    if len(registered_adapters) >= capacity:
        raise RuntimeError('No free adapter slots.')

    add_func(adapter)
    registered_adapters[adapter.id] = adapter
    return True
def set_adapter_mapping(mapping: Any, last_mapping: Any,
                        set_mapping_func: Callable) -> Any:
    """Apply ``mapping`` only when it differs from ``last_mapping``.

    Args:
        mapping: The requested mapping.
        last_mapping: The mapping applied previously.
        set_mapping_func: Called with ``mapping`` when it has changed.

    Returns:
        ``mapping`` if it was applied, otherwise ``last_mapping``.
    """
    changed = last_mapping != mapping
    if not changed:
        return last_mapping

    set_mapping_func(mapping)
    return mapping
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
                   deactivate_func: Callable) -> bool:
    """Deactivate ``adapter_id`` and drop it from the registry.

    Args:
        adapter_id: Id of the adapter to remove.
        registered_adapters: Registry keyed by adapter id; updated in place.
        deactivate_func: Always called with ``adapter_id`` first, whether or
            not the id is registered.

    Returns:
        True if the id was registered (and has now been removed), else False.
    """
    deactivate_func(adapter_id)
    # Detect presence with a unique sentinel rather than the truthiness of
    # the stored value, so a registered adapter whose object happens to be
    # falsy still reports a successful removal.
    missing = object()
    return registered_adapters.pop(adapter_id, missing) is not missing
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
    """Return a new dict built from the registered adapters."""
    snapshot = dict(registered_adapters)
    return snapshot
def get_adapter(adapter_id: int,
                registered_adapters: Dict[int, Any]) -> Optional[Any]:
    """Look up ``adapter_id`` in the registry.

    Returns:
        The registered adapter, or ``None`` when the id is not present.
    """
    found = registered_adapters.get(adapter_id)
    return found
## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
                               apply_adapters_func,
                               set_adapter_mapping_func) -> None:
    """Apply the requested adapters, then set the adapter mapping.

    Args:
        requests: Adapter requests, forwarded to ``apply_adapters_func``.
        mapping: Mapping (or ``None``), forwarded as-is to
            ``set_adapter_mapping_func``.
        apply_adapters_func: Callback taking ``requests``; called first.
        set_adapter_mapping_func: Callback taking ``mapping``; called second.
    """
    # Order matters: adapters are applied before the mapping is set.
    apply_adapters_func(requests)
    set_adapter_mapping_func(mapping)
def add_adapter_worker(adapter_request: Any, list_adapters_func,
                       load_adapter_func, add_adapter_func,
                       activate_adapter_func) -> bool:
    """Load, register and activate the adapter described by a request.

    Args:
        adapter_request: Object exposing ``adapter_id``.
        list_adapters_func: Zero-argument callback returning a container of
            the adapter ids that are already present.
        load_adapter_func: Callback turning ``adapter_request`` into a loaded
            adapter object exposing ``id``.
        add_adapter_func: Callback registering the loaded adapter; its return
            value becomes this function's result.
        activate_adapter_func: Callback invoked with the loaded adapter's
            ``id`` after registration.

    Returns:
        False, without loading anything, if the requested id is already
        listed; otherwise whatever ``add_adapter_func`` returned.
    """
    requested_id = adapter_request.adapter_id
    if requested_id in list_adapters_func():
        return False
    adapter = load_adapter_func(adapter_request)
    was_added = add_adapter_func(adapter)
    # Activation happens regardless of what add_adapter_func reported.
    activate_adapter_func(adapter.id)
    return was_added
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
                          adapter_slots: int, remove_adapter_func,
                          add_adapter_func) -> None:
    """Reconcile the present adapters with the set of requested ones.

    Adapters that are present but no longer requested are removed first,
    then requested adapters that are not yet present are added.

    Args:
        adapter_requests: Requests exposing ``adapter_id``; falsy entries
            (e.g. ``None``) are ignored.
        list_adapters_func: Zero-argument callback returning the set of ids
            currently present.
        adapter_slots: Maximum number of distinct adapters that may be
            requested at once.
        remove_adapter_func: Callback taking an adapter id to remove.
        add_adapter_func: Callback taking the request to add.

    Raises:
        RuntimeError: If more distinct adapters are requested than
            ``adapter_slots`` allows.
    """
    present_ids = list_adapters_func()

    # Index the non-empty requests by id; a later duplicate id wins.
    requests_by_id = {}
    for request in adapter_requests:
        if request:
            requests_by_id[request.adapter_id] = request

    if len(requests_by_id) > adapter_slots:
        raise RuntimeError(
            f"Number of requested models ({len(requests_by_id)}) is greater "
            f"than the number of GPU model slots "
            f"({adapter_slots}).")

    requested_ids = set(requests_by_id)
    ids_to_add = requested_ids - present_ids
    ids_to_remove = present_ids - requested_ids

    for stale_id in ids_to_remove:
        remove_adapter_func(stale_id)

    for new_id in ids_to_add:
        add_adapter_func(requests_by_id[new_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
    """Return the ids yielded by iterating the manager's adapter listing.

    ``adapter_manager_list_adapters_func`` is called with no arguments and
    its result is iterated into a new set (for a dict, that is its keys).
    """
    listing = adapter_manager_list_adapters_func()
    return set(listing)
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
import torch
class AbstractWorkerManager(ABC):
    """Abstract interface for a worker-side adapter manager.

    This base class only stores the ``torch.device`` it is constructed with;
    every other member is abstract and must be supplied by a subclass.
    """

    def __init__(self, device: torch.device):
        # Stored as-is; nothing else is initialised here.
        self.device = device

    @property
    @abstractmethod
    def is_enabled(self) -> bool:
        """Abstract flag; presumably whether adapter support is active for
        this worker -- TODO confirm against the concrete manager."""
        ...

    @abstractmethod
    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        """Abstract hook taking a set of adapter requests and an optional
        mapping.

        NOTE(review): the element type of ``requests`` and the structure of
        ``mapping`` are not visible here (typed ``Any``).
        """
        ...

    @abstractmethod
    def add_adapter(self, adapter_request: Any) -> bool:
        """Abstract hook for adding the adapter described by
        ``adapter_request``; the meaning of the returned bool is defined by
        the subclass."""
        ...

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        """Abstract hook for removing the adapter with id ``adapter_id``;
        the meaning of the returned bool is defined by the subclass."""
        ...

    @abstractmethod
    def remove_all_adapters(self):
        """Abstract hook for removing every adapter; no return type is
        declared on this interface."""
        ...

    @abstractmethod
    def list_adapters(self) -> Set[int]:
        """Abstract hook returning a set of ints; presumably the ids of the
        adapters currently held -- TODO confirm."""
        ...
......@@ -1285,6 +1285,39 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.")
@dataclass
class PromptAdapterConfig:
    """Configuration for serving prompt adapters.

    Construction validates the numeric limits and requires the ``peft``
    package to be importable.
    """
    # Must be >= 1.
    max_prompt_adapters: int
    # Must be a positive value; 0 is treated as "not set".
    max_prompt_adapter_token: int
    # Falls back to ``max_prompt_adapters`` when left as None.
    max_cpu_prompt_adapters: Optional[int] = None
    # None, "auto" or a torch dtype name are resolved by
    # ``verify_with_model_config``.
    prompt_adapter_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        """Check that ``peft`` is available and validate the limits.

        Raises:
            ImportError: If ``peft`` cannot be imported.
            ValueError: If ``max_prompt_adapters`` is below 1 or
                ``max_prompt_adapter_token`` is not positive.
        """
        library_name = 'peft'
        try:
            __import__(library_name)
        except ImportError as e:
            # The two sentences need an explicit separator: adjacent string
            # literals are concatenated with nothing in between.
            raise ImportError(
                f"'{library_name}' is not installed for prompt adapter "
                f"support. Please install it using "
                f"'pip install {library_name}'.") from e

        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        # Reject negative values as well as the "unset" value 0; a negative
        # token budget is never meaningful.
        if self.max_prompt_adapter_token < 1:
            raise ValueError(
                "max_prompt_adapter_token must be set to a positive value "
                f"(got {self.max_prompt_adapter_token}).")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` into a concrete value.

        ``None`` and ``"auto"`` take ``model_config.dtype``; any other
        string is looked up as an attribute of ``torch`` (so an unknown name
        raises ``AttributeError``). Non-string values are left untouched.
        """
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)
@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
......@@ -1518,6 +1551,7 @@ class EngineConfig:
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
......@@ -1529,6 +1563,9 @@ class EngineConfig:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
......
......@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
......@@ -139,6 +140,8 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
......@@ -157,6 +160,14 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
    """Distinct prompt adapter requests among the scheduled sequence groups.

    Groups whose ``seq_group.prompt_adapter_request`` is ``None`` are
    skipped.
    """
    collected = set()
    for scheduled in self.scheduled_seq_groups:
        request = scheduled.seq_group.prompt_adapter_request
        if request is not None:
            collected.add(request)
    return collected
@dataclass
class SchedulerRunningOutputs:
......@@ -1024,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
seq_group_metadata_list.append(seq_group_metadata)
......
......@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
......@@ -66,6 +66,9 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
......@@ -449,6 +452,17 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
......@@ -726,6 +740,11 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
)
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
......@@ -751,6 +770,7 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
)
......
......@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
......@@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
......@@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine):
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
......@@ -286,13 +294,14 @@ class _AsyncLLMEngine(LLMEngine):
return self.input_processor(llm_inputs)
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
......@@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request(
request_id=request_id,
......@@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine):
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)
......@@ -627,6 +640,7 @@ class AsyncLLMEngine:
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
......@@ -669,7 +683,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
)
prompt_adapter_request=prompt_adapter_request)
return stream
......@@ -680,6 +694,7 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
......@@ -695,6 +710,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields:
The output `RequestOutput` objects from the LLMEngine
......@@ -749,6 +766,7 @@ class AsyncLLMEngine:
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
):
yield LLMEngine.validate_output(output, RequestOutput)
......@@ -837,6 +855,7 @@ class AsyncLLMEngine:
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
......@@ -849,6 +868,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
try:
......
......@@ -8,7 +8,8 @@ from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
......@@ -27,6 +28,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
......@@ -93,6 +95,8 @@ class LLMEngine:
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
......@@ -161,6 +165,7 @@ class LLMEngine:
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
......@@ -222,6 +227,7 @@ class LLMEngine:
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
......@@ -250,6 +256,7 @@ class LLMEngine:
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
)
if not self.model_config.embedding_mode:
......@@ -282,6 +289,8 @@ class LLMEngine:
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prompt_adapter":
bool(prompt_adapter_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
......@@ -376,7 +385,6 @@ class LLMEngine:
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
......@@ -409,7 +417,6 @@ class LLMEngine:
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
......@@ -470,6 +477,9 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
......@@ -487,6 +497,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
) -> None:
# Create the sequences.
......@@ -495,7 +506,7 @@ class LLMEngine:
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
lora_request, prompt_adapter_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
......@@ -506,7 +517,7 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
)
prompt_adapter_request=prompt_adapter_request)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
......@@ -514,7 +525,7 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
prompt_adapter_request=prompt_adapter_request)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
......@@ -535,6 +546,7 @@ class LLMEngine:
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
......@@ -549,6 +561,11 @@ class LLMEngine:
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
......@@ -563,6 +580,7 @@ class LLMEngine:
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
......@@ -612,9 +630,11 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
processed_inputs = self.process_model_inputs(
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request(
request_id=request_id,
......@@ -622,6 +642,7 @@ class LLMEngine:
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)
......@@ -633,6 +654,7 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
......@@ -658,7 +680,7 @@ class LLMEngine:
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
prompt_adapter_request=prompt_adapter_request)
return seq_group
......@@ -669,16 +691,19 @@ class LLMEngine:
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params)
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
......@@ -1082,6 +1107,16 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
    """Forward ``lora_id`` to ``self.model_executor.pin_lora`` and return
    its result unchanged."""
    return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Forward the request to ``self.model_executor.add_prompt_adapter``
    and return its result unchanged."""
    return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward ``prompt_adapter_id`` to
    ``self.model_executor.remove_prompt_adapter`` and return its result
    unchanged."""
    return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
    """Return whatever ``self.model_executor.list_prompt_adapters()``
    returns, without conversion.

    NOTE(review): the annotation says ``List[int]``, but the executor's
    result is passed through as-is -- confirm it is really a list and not,
    for example, a set of ids.
    """
    return self.model_executor.list_prompt_adapters()
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
......
......@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
......@@ -255,6 +256,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
......@@ -271,6 +273,8 @@ class LLM:
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `RequestOutput` objects containing the
......@@ -304,7 +308,7 @@ class LLM:
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
)
prompt_adapter_request=prompt_adapter_request)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
......@@ -397,6 +401,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
......@@ -412,6 +417,8 @@ class LLM:
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
......@@ -445,6 +452,7 @@ class LLM:
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
......@@ -504,6 +512,7 @@ class LLM:
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
......@@ -526,19 +535,23 @@ class LLM:
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
)
prompt_adapter_request=prompt_adapter_request)
def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
inputs,
params,
lora_request=lora_request)
self.llm_engine.add_request(
request_id,
inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
def _run_engine(
self, *, use_tqdm: bool
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment