Commit 69e1d2fb (unverified), authored by Antoni Baum, committed by GitHub
Browse files

[Core] Refactor model loading code (#4097)

parent 05434764
...@@ -92,7 +92,7 @@ steps: ...@@ -92,7 +92,7 @@ steps:
parallelism: 4 parallelism: 4
- label: Tensorizer Test - label: Tensorizer Test
command: apt-get install curl libsodium23 && pytest -v -s tensorizer command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
- label: Metrics Test - label: Metrics Test
command: pytest -v -s metrics command: pytest -v -s metrics
......
...@@ -11,7 +11,7 @@ from safetensors.torch import safe_open ...@@ -11,7 +11,7 @@ from safetensors.torch import safe_open
from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.model_executor.layers.quantization.schema import QuantParamSchema
# Adapted from vllm/model_executor/weight_utils.py # Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify # The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that # its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it) # the quantized model exists locally and there is no need to download it)
...@@ -71,7 +71,7 @@ def _prepare_hf_weights( ...@@ -71,7 +71,7 @@ def _prepare_hf_weights(
return hf_weights_files, use_safetensors return hf_weights_files, use_safetensors
# Adapted from vllm/model_executor/weight_utils.py # Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str, def _hf_tensorfile_iterator(filename: str, load_format: str,
use_safetensors: bool): use_safetensors: bool):
if load_format == "npz": if load_format == "npz":
......
...@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig ...@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
......
...@@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: ...@@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup() cleanup()
get_model_old = get_model get_model_old = get_model
def get_model_patched(model_config, device_config, **kwargs): def get_model_patched(*, model_config, device_config, **kwargs):
return get_model_old(model_config, kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
device_config, return get_model_old(model_config=model_config,
lora_config=LoRAConfig(max_loras=4, device_config=device_config,
max_lora_rank=8)) **kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched): with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
......
...@@ -3,8 +3,8 @@ import random ...@@ -3,8 +3,8 @@ import random
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig) ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.lora.models import LoRAMapping from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files): ...@@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files):
"meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
), ),
load_config=LoadConfig(
download_dir=None,
load_format="dummy",
),
parallel_config=ParallelConfig(1, 1, False), parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32), scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"), device_config=DeviceConfig("cuda"),
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import huggingface_hub.constants import huggingface_hub.constants
import pytest import pytest
from vllm.model_executor.weight_utils import enable_hf_transfer from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
def test_hf_transfer_auto_activation(): def test_hf_transfer_auto_activation():
......
...@@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: ...@@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path, model_path,
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
...@@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: ...@@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path, model_path,
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
......
...@@ -32,7 +32,12 @@ def _prepare_test( ...@@ -32,7 +32,12 @@ def _prepare_test(
1e-2, 1e-2,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits) sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, sampler, model_runner return input_tensor, fake_logits, sampler, model_runner
...@@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=input_tensor.device, device=input_tensor.device,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits) sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
generation_model = GenerationMixin() generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k, generation_config = GenerationConfig(top_k=top_k,
......
...@@ -118,6 +118,7 @@ def create_worker(cls: type, ...@@ -118,6 +118,7 @@ def create_worker(cls: type,
scheduler_config=engine_config.scheduler_config, scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config, device_config=engine_config.device_config,
cache_config=engine_config.cache_config, cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
...@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig ...@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
# yapf: disable # yapf: disable
...@@ -74,7 +74,7 @@ def parse_args(): ...@@ -74,7 +74,7 @@ def parse_args():
"extremely quickly. Tensor encryption and decryption is " "extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to " "also supported, although libsodium must be installed to "
"use it.") "use it.")
parser = EngineArgs.add_cli_args(parser) parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
subparsers = parser.add_subparsers(dest='command') subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser( serialize_parser = subparsers.add_parser(
......
import gc import gc
import json
import os
import subprocess import subprocess
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import openai
import pytest import pytest
import ray
import torch import torch
from tests.entrypoints.test_openai_server import ServerRunner from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import (
from vllm.model_executor.tensorizer_loader import ( EncryptionParams, TensorizerConfig, TensorSerializer,
EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
load_with_tensorizer, open_stream)
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -22,6 +25,8 @@ prompts = [ ...@@ -22,6 +25,8 @@ prompts = [
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def is_curl_installed(): def is_curl_installed():
...@@ -38,7 +43,7 @@ def tensorizer_config(): ...@@ -38,7 +43,7 @@ def tensorizer_config():
return config return config
@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') @patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config): def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock() mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value mock_agent_instance = mock_agent.return_value
...@@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
del vllm_model, model del vllm_model, model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref, loaded_vllm_model = vllm_runner(
model_ref,
load_format="tensorizer", load_format="tensorizer",
tensorizer_uri=model_path, model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
num_readers=1, num_readers=1,
vllm_tensorized=True) vllm_tensorized=True),
)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic # Assumes SamplingParams being seeded ensures the outputs are deterministic
...@@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner): ...@@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
loaded_hf_model = vllm_runner( loaded_hf_model = vllm_runner(model_ref,
model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1, num_readers=1,
vllm_tensorized=False, vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
) ))
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
...@@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref, loaded_vllm_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path, encryption_keyfile=key_path,
num_readers=1, num_readers=1,
vllm_tensorized=True) vllm_tensorized=True))
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
...@@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref, loaded_hf_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1, num_readers=1,
vllm_tensorized=False) vllm_tensorized=False))
deserialized_outputs = loaded_hf_model.generate_greedy( deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens) prompts, max_tokens=max_tokens)
...@@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
torch.cuda.empty_cache() torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner( loaded_vllm_model = vllm_runner(
model_ref, model_ref,
tensorizer_uri=model_path,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1, num_readers=1,
vllm_tensorized=True, vllm_tensorized=True,
),
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
...@@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def test_load_without_tensorizer_load_format(vllm_runner): def test_load_without_tensorizer_load_format(vllm_runner):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner(model_ref, tensorizer_uri="test") vllm_runner(model_ref,
model_loader_extra_config=TensorizerConfig(
tensorizer_uri="test", vllm_tensorized=False))
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path): def test_tensorize_vllm_model(tmp_path):
# Test serialize command # Test serialize command
serialize_args = [ serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", "python3", tensorize_model_for_testing_script, "--model", model_ref,
model_ref, "--dtype", "float16", "serialize", "--serialized-directory", "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
tmp_path, "--suffix", "tests" "--suffix", "tests"
] ]
result = subprocess.run(serialize_args, capture_output=True, text=True) result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command print(result.stdout) # Print the output of the serialize command
...@@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path): ...@@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path):
# Test deserialize command # Test deserialize command
deserialize_args = [ deserialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", "python3", tensorize_model_for_testing_script, "--model", model_ref,
model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", "--dtype", "float16", "deserialize", "--path-to-tensors",
path_to_tensors path_to_tensors
] ]
result = subprocess.run(deserialize_args, capture_output=True, text=True) result = subprocess.run(deserialize_args, capture_output=True, text=True)
...@@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path): ...@@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path):
def test_openai_apiserver_with_tensorizer(tmp_path): def test_openai_apiserver_with_tensorizer(tmp_path):
## Serialize model ## Serialize model
serialize_args = [ serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", "python3", tensorize_model_for_testing_script, "--model", model_ref,
model_ref, "--dtype", "float16", "serialize", "--serialized-directory", "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
tmp_path, "--suffix", "tests" "--suffix", "tests"
] ]
result = subprocess.run(serialize_args, capture_output=True, text=True) result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command print(result.stdout) # Print the output of the serialize command
...@@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path): ...@@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
f"\n{result.stdout}\n{result.stderr}") f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
model_loader_extra_config = {
"tensorizer_uri": path_to_tensors,
"vllm_tensorized": True
}
## Start OpenAI API server ## Start OpenAI API server
openai_args = [ openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format", "--model", model_ref, "--dtype", "float16", "--load-format",
"tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", "tensorizer", "--model-loader-extra-config",
"--port", "8000" json.dumps(model_loader_extra_config), "--port", "8000"
] ]
server = ServerRunner.remote(openai_args) server = ServerRunner.remote(openai_args)
assert ray.get(server.ready.remote())
print("Server ready.") print("Server ready.")
assert server.ready.remote()
client = openai.OpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
completion = client.completions.create(model=model_ref,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner): def test_raise_value_error_on_invalid_load_format(vllm_runner):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner(model_ref, vllm_runner(model_ref,
load_format="safetensors", load_format="safetensors",
tensorizer_uri="test") model_loader_extra_config=TensorizerConfig(
tensorizer_uri="test", vllm_tensorized=False))
def test_tensorizer_with_tp(vllm_runner): def test_tensorizer_with_tp(vllm_runner):
...@@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner): ...@@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner):
vllm_runner( vllm_runner(
model_ref, model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1, num_readers=1,
vllm_tensorized=False, vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2, tensor_parallel_size=2,
) )
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorizer_warn_quant(tmp_path):
model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
"serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
assert 'PerformanceWarning' in result.stderr
...@@ -11,8 +11,6 @@ def test_get_sliding_window(): ...@@ -11,8 +11,6 @@ def test_get_sliding_window():
"Qwen/Qwen1.5-7B", "Qwen/Qwen1.5-7B",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
...@@ -30,8 +28,6 @@ def test_get_sliding_window(): ...@@ -30,8 +28,6 @@ def test_get_sliding_window():
"mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
......
...@@ -37,7 +37,12 @@ def _prepare_test( ...@@ -37,7 +37,12 @@ def _prepare_test(
1e-2, 1e-2,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, logits_processor, model_runner return input_tensor, fake_logits, logits_processor, model_runner
......
...@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size): ...@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
100000, 100000,
100000, 100000,
enable_chunked_prefill=False) enable_chunked_prefill=False)
model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=scheduler_config,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
...@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
"facebook/opt-125m", "facebook/opt-125m",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
...@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
100000, 100000,
100000, 100000,
enable_chunked_prefill=False) enable_chunked_prefill=False)
model_runner = ModelRunner(model_config, None, scheduler_config, None, model_runner = ModelRunner(model_config=model_config,
None) parallel_config=None,
scheduler_config=scheduler_config,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
...@@ -205,14 +212,17 @@ def test_empty_seq_group(): ...@@ -205,14 +212,17 @@ def test_empty_seq_group():
"facebook/opt-125m", "facebook/opt-125m",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
enforce_eager=False, enforce_eager=False,
) )
model_runner = ModelRunner(model_config, None, None, None, None) model_runner = ModelRunner(model_config=model_config,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
seq_group_metadata_list = [] seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
...@@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
"facebook/opt-125m", "facebook/opt-125m",
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0, seed=0,
dtype="float16", dtype="float16",
revision=None, revision=None,
...@@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
100000, 100000,
100000, 100000,
enable_chunked_prefill=True) enable_chunked_prefill=True)
model_runner = ModelRunner(model_config, model_runner = ModelRunner(model_config=model_config,
None, parallel_config=None,
scheduler_config, scheduler_config=scheduler_config,
None, device_config=None,
None, load_config=None,
lora_config=None,
is_driver_worker=True) is_driver_worker=True)
model_runner.set_block_size(16) model_runner.set_block_size(16)
......
...@@ -23,6 +23,7 @@ def test_swap() -> None: ...@@ -23,6 +23,7 @@ def test_swap() -> None:
scheduler_config=engine_config.scheduler_config, scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config, device_config=engine_config.device_config,
cache_config=engine_config.cache_config, cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
import enum import enum
import io
import json import json
import os import os
import typing from dataclasses import dataclass, field, fields
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch import torch
...@@ -18,10 +16,14 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, ...@@ -18,10 +16,14 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.model_executor.model_loader.loader import BaseModelLoader
logger = init_logger(__name__) logger = init_logger(__name__)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
"False").lower() == "true"
_GB = 1 << 30 _GB = 1 << 30
...@@ -35,18 +37,6 @@ class ModelConfig: ...@@ -35,18 +37,6 @@ class ModelConfig:
available, and "slow" will always use the slow tokenizer. available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer. downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models. for BF16 models.
...@@ -83,8 +73,6 @@ class ModelConfig: ...@@ -83,8 +73,6 @@ class ModelConfig:
tokenizer: str, tokenizer: str,
tokenizer_mode: str, tokenizer_mode: str,
trust_remote_code: bool, trust_remote_code: bool,
download_dir: Optional[str],
load_format: str,
dtype: Union[str, torch.dtype], dtype: Union[str, torch.dtype],
seed: int, seed: int,
revision: Optional[str] = None, revision: Optional[str] = None,
...@@ -101,8 +89,6 @@ class ModelConfig: ...@@ -101,8 +89,6 @@ class ModelConfig:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code self.trust_remote_code = trust_remote_code
self.download_dir = download_dir
self.load_format = load_format
self.seed = seed self.seed = seed
self.revision = revision self.revision = revision
self.code_revision = code_revision self.code_revision = code_revision
...@@ -113,64 +99,16 @@ class ModelConfig: ...@@ -113,64 +99,16 @@ class ModelConfig:
self.max_context_len_to_capture = max_context_len_to_capture self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
if not os.path.exists(model):
model_path = snapshot_download(model_id=model,
cache_dir=download_dir,
revision=revision)
else:
model_path = model
self.model = model_path
self.download_dir = model_path
self.tokenizer = model_path
self.hf_config = get_config(self.model, trust_remote_code, revision, self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision) code_revision)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config, self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len) max_model_len)
self._verify_load_format()
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
def _verify_load_format(self) -> None:
    """Lower-case ``self.load_format`` and validate it.

    On success the lower-cased value is written back to
    ``self.load_format``.

    Raises:
        ValueError: If the format is not one of the supported names, if
            ``is_hip()`` is true and the format is in the ROCm exclusion
            list, or if "pt" is requested while ``hf_config.architectures``
            contains "MixtralForCausalLM".
    """
    normalized = self.load_format.lower()
    known_formats = [
        "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
    ]
    # The ROCm exclusion list is currently empty.
    rocm_excluded: List[str] = []

    if normalized not in known_formats:
        raise ValueError(
            f"Unknown load format: {self.load_format}. Must be one of "
            "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
            "'dummy'.")

    if is_hip() and normalized in rocm_excluded:
        usable_on_rocm = [
            name for name in known_formats if name not in rocm_excluded
        ]
        raise ValueError(
            f"load format '{normalized}' is not supported in ROCm. "
            f"Supported load format are "
            f"{usable_on_rocm}")

    # TODO: Remove this check once HF updates the pt weights of Mixtral.
    # ``architectures`` can be None instead of [], so test truthiness
    # before testing membership.
    model_archs = getattr(self.hf_config, "architectures", [])
    targets_mixtral = (bool(model_archs)
                       and "MixtralForCausalLM" in model_archs)
    if targets_mixtral and normalized == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead. ")

    self.load_format = normalized
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower() tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]: if tokenizer_mode not in ["auto", "slow"]:
...@@ -471,6 +409,65 @@ class TokenizerPoolConfig: ...@@ -471,6 +409,65 @@ class TokenizerPoolConfig:
return tokenizer_pool_config return tokenizer_pool_config
class LoadFormat(str, enum.Enum):
    """Names of the weight-loading formats accepted as ``load_format``.

    Members also subclass ``str``, so each member compares equal to its
    lower-case string value and can be looked up from it, e.g.
    ``LoadFormat("auto") is LoadFormat.AUTO``.
    """

    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
@dataclass
class LoadConfig:
    """Configuration describing where and how model weights are loaded.

    Attributes:
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format
                and fall back to the pytorch bin format if safetensors
                format is not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which
                is mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            String values are matched case-insensitively and converted to
            the corresponding ``LoadFormat`` member; non-string values are
            stored unchanged.
        model_loader_extra_config: Extra options for the model loader,
            given as a dict or as a JSON string that is parsed on
            construction. ``None`` or an empty value becomes ``{}``.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)

    def __post_init__(self):
        """Normalize ``model_loader_extra_config`` and check the format."""
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        # Always store the normalized value so that a ``None`` argument
        # does not leak through to code that expects a mapping.
        self.model_loader_extra_config = extra_config
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` to ``LoadFormat`` and validate it.

        Raises:
            ValueError: If the string names no ``LoadFormat`` member, or if
                ``is_hip()`` is true and the format is in the ROCm
                exclusion list.
        """
        # Non-string values are accepted as-is without validation.
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        try:
            self.load_format = LoadFormat(load_format)
        except ValueError as err:
            # The bare enum error does not tell the user what is accepted.
            supported_load_format = [f.value for f in LoadFormat]
            raise ValueError(
                f"Unknown load format: {self.load_format}. "
                f"Must be one of {supported_load_format}.") from err

        # The ROCm exclusion list is currently empty.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare member *values* (lower-case), not member names
            # (upper-case), against the exclusion list.
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")
class ParallelConfig: class ParallelConfig:
"""Configuration for the distributed execution. """Configuration for the distributed execution.
...@@ -699,8 +696,6 @@ class SpeculativeConfig: ...@@ -699,8 +696,6 @@ class SpeculativeConfig:
tokenizer=target_model_config.tokenizer, tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode, tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code, trust_remote_code=target_model_config.trust_remote_code,
download_dir=target_model_config.download_dir,
load_format=target_model_config.load_format,
dtype=target_model_config.dtype, dtype=target_model_config.dtype,
seed=target_model_config.seed, seed=target_model_config.seed,
revision=draft_revision, revision=draft_revision,
...@@ -887,65 +882,6 @@ class VisionLanguageConfig: ...@@ -887,65 +882,6 @@ class VisionLanguageConfig:
f"{[x.name for x in cls.ImageInputType]}.") from e f"{[x.name for x in cls.ImageInputType]}.") from e
@dataclass
class TensorizerConfig:
    """Settings used when loading a model through Tensorizer.

    The first eight fields are forwarded to ``TensorizerArgs`` by
    ``_construct_tensorizer_args``; ``model_class``, ``hf_config`` and
    ``dtype`` are stored on the config but not forwarded there.
    """

    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
                          str, bytes, os.PathLike, int]
    vllm_tensorized: bool
    verify_hash: Optional[bool] = False
    num_readers: Optional[int] = 1
    encryption_keyfile: Optional[str] = None
    s3_access_key_id: Optional[str] = None
    s3_secret_access_key: Optional[str] = None
    s3_endpoint: Optional[str] = None
    model_class: Optional[torch.nn.Module] = None
    hf_config: Optional[PretrainedConfig] = None
    dtype: Union[str, torch.dtype] = None

    def _construct_tensorizer_args(self) -> "TensorizerArgs":
        """Return a ``TensorizerArgs`` built from this config's fields."""
        from vllm.model_executor.tensorizer_loader import TensorizerArgs

        return TensorizerArgs(
            tensorizer_uri=self.tensorizer_uri,
            vllm_tensorized=self.vllm_tensorized,
            verify_hash=self.verify_hash,
            num_readers=self.num_readers,
            encryption_keyfile=self.encryption_keyfile,
            s3_access_key_id=self.s3_access_key_id,
            s3_secret_access_key=self.s3_secret_access_key,
            s3_endpoint=self.s3_endpoint,
        )

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Reject tensor parallelism when a tensorizer URI is configured.

        Raises:
            ValueError: If ``tensor_parallel_size`` is greater than 1 and
                ``tensorizer_uri`` is not None.
        """
        uses_multiple_gpus = parallel_config.tensor_parallel_size > 1
        if not uses_multiple_gpus or self.tensorizer_uri is None:
            return
        raise ValueError(
            "Loading to multiple GPUs is not currently supported with "
            "vLLM-serialized models. Please set tensor_parallel_size=1."
            " or use a non-vLLM-serialized model, such as a "
            "serialized Hugging Face `PretrainedModel`.")

    def verify_with_model_config(self, model_config) -> None:
        """Check this config against the model's quantization and format.

        Emits a tensorizer warning when quantization is combined with a
        tensorizer URI.

        Raises:
            ValueError: If a tensorizer URI is set but the model's load
                format is not "tensorizer".
        """
        is_quantized = model_config.quantization is not None
        uri_given = self.tensorizer_uri is not None

        if is_quantized and uri_given:
            from vllm.model_executor.tensorizer_loader import (
                tensorizer_warning)
            tensorizer_warning(
                "Loading a model using Tensorizer with quantization on vLLM"
                " is unstable and may lead to errors.")

        if model_config.load_format != "tensorizer" and uri_given:
            raise ValueError(
                "A tensorizer uri was passed for tensorizer loading, but the "
                f"load format was set to {model_config.load_format}. "
                "Please set the load format to 'tensorizer' to use "
                f"tensorizer args.")
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,
...@@ -1105,11 +1041,11 @@ class EngineConfig: ...@@ -1105,11 +1041,11 @@ class EngineConfig:
parallel_config: ParallelConfig parallel_config: ParallelConfig
scheduler_config: SchedulerConfig scheduler_config: SchedulerConfig
device_config: DeviceConfig device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig] lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig] vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig] decoding_config: Optional[DecodingConfig]
tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
...@@ -1117,11 +1053,6 @@ class EngineConfig: ...@@ -1117,11 +1053,6 @@ class EngineConfig:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
self.tensorizer_config.verify_with_model_config(self.model_config)
if self.lora_config: if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
......
import argparse import argparse
import dataclasses import dataclasses
import io
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import BinaryIO, Optional, Union from typing import Optional
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, TensorizerConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig) TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple from vllm.utils import str_to_int_tuple
...@@ -60,17 +57,7 @@ class EngineArgs: ...@@ -60,17 +57,7 @@ class EngineArgs:
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
# Tensorizer configuration parameters
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int] = None
vllm_tensorized: bool = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
image_input_type: Optional[str] = None image_input_type: Optional[str] = None
...@@ -429,7 +416,16 @@ class EngineArgs: ...@@ -429,7 +416,16 @@ class EngineArgs:
default=None, default=None,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding') 'the draft model in speculative decoding')
parser = TensorizerArgs.add_cli_args(parser)
parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. '
'This will be passed to the model loader '
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
return parser return parser
@classmethod @classmethod
...@@ -444,11 +440,11 @@ class EngineArgs: ...@@ -444,11 +440,11 @@ class EngineArgs:
device_config = DeviceConfig(self.device) device_config = DeviceConfig(self.device)
model_config = ModelConfig( model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.download_dir, self.load_format, self.trust_remote_code, self.dtype, self.seed, self.revision,
self.dtype, self.seed, self.revision, self.code_revision, self.code_revision, self.tokenizer_revision, self.max_model_len,
self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization, self.quantization_param_path,
self.quantization_param_path, self.enforce_eager, self.enforce_eager, self.max_context_len_to_capture,
self.max_context_len_to_capture, self.max_logprobs) self.max_logprobs)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
...@@ -492,15 +488,10 @@ class EngineArgs: ...@@ -492,15 +488,10 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
tensorizer_config = TensorizerConfig( load_config = LoadConfig(
tensorizer_uri=self.tensorizer_uri, load_format=self.load_format,
vllm_tensorized=self.vllm_tensorized, download_dir=self.download_dir,
verify_hash=self.verify_hash, model_loader_extra_config=self.model_loader_extra_config,
num_readers=self.num_readers,
encryption_keyfile=self.encryption_keyfile,
s3_access_key_id=self.s3_access_key_id,
s3_secret_access_key=self.s3_secret_access_key,
s3_endpoint=self.s3_endpoint,
) )
if self.image_input_type: if self.image_input_type:
...@@ -530,8 +521,8 @@ class EngineArgs: ...@@ -530,8 +521,8 @@ class EngineArgs:
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
speculative_config=speculative_config, speculative_config=speculative_config,
decoding_config=decoding_config, load_config=load_config,
tensorizer_config=tensorizer_config) decoding_config=decoding_config)
@dataclass @dataclass
......
...@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union ...@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig, ModelConfig, ParallelConfig,
SpeculativeConfig, TensorizerConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -72,11 +72,11 @@ class LLMEngine: ...@@ -72,11 +72,11 @@ class LLMEngine:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig], decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -92,8 +92,8 @@ class LLMEngine: ...@@ -92,8 +92,8 @@ class LLMEngine:
f"trust_remote_code={model_config.trust_remote_code}, " f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, " f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, " f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, " f"download_dir={load_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={load_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce=" f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, " f"{parallel_config.disable_custom_all_reduce}, "
...@@ -114,8 +114,8 @@ class LLMEngine: ...@@ -114,8 +114,8 @@ class LLMEngine:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig() self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats self.log_stats = log_stats
self._init_tokenizer() self._init_tokenizer()
...@@ -131,7 +131,7 @@ class LLMEngine: ...@@ -131,7 +131,7 @@ class LLMEngine:
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
speculative_config=speculative_config, speculative_config=speculative_config,
tensorizer_config=tensorizer_config, load_config=load_config,
) )
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -271,9 +271,6 @@ class LLMEngine: ...@@ -271,9 +271,6 @@ class LLMEngine:
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
if self.lora_config: if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
......
...@@ -40,6 +40,7 @@ class CPUExecutor(ExecutorBase): ...@@ -40,6 +40,7 @@ class CPUExecutor(ExecutorBase):
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
device_config=self.device_config, device_config=self.device_config,
cache_config=self.cache_config, cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment