Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69e1d2fb
Unverified
Commit
69e1d2fb
authored
Apr 16, 2024
by
Antoni Baum
Committed by
GitHub
Apr 16, 2024
Browse files
[Core] Refactor model loading code (#4097)
parent
05434764
Changes
67
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
244 additions
and
279 deletions
+244
-279
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-1
examples/fp8/extract_scales.py
examples/fp8/extract_scales.py
+2
-2
examples/tensorize_vllm_model.py
examples/tensorize_vllm_model.py
+1
-1
tests/lora/conftest.py
tests/lora/conftest.py
+5
-5
tests/lora/test_worker.py
tests/lora/test_worker.py
+6
-4
tests/model_executor/weight_utils.py
tests/model_executor/weight_utils.py
+1
-1
tests/quantization/test_autogptq_marlin_configs.py
tests/quantization/test_autogptq_marlin_configs.py
+0
-4
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+12
-2
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+1
-0
tests/tensorizer_loader/__init__.py
tests/tensorizer_loader/__init__.py
+0
-0
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
+2
-2
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+82
-57
tests/test_config.py
tests/test_config.py
+0
-4
tests/test_logits_processor.py
tests/test_logits_processor.py
+6
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+24
-15
tests/worker/test_swap.py
tests/worker/test_swap.py
+1
-0
vllm/config.py
vllm/config.py
+66
-135
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+25
-34
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+8
-11
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
69e1d2fb
...
...
@@ -92,7 +92,7 @@ steps:
parallelism
:
4
-
label
:
Tensorizer Test
command
:
apt-get install curl libsodium23 && pytest -v -s tensorizer
command
:
apt-get install curl libsodium23 && pytest -v -s tensorizer
_loader
-
label
:
Metrics Test
command
:
pytest -v -s metrics
...
...
examples/fp8/extract_scales.py
View file @
69e1d2fb
...
...
@@ -11,7 +11,7 @@ from safetensors.torch import safe_open
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
# Adapted from vllm/model_executor/weight_utils.py
# Adapted from vllm/model_executor/
model_loader/
weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
...
...
@@ -71,7 +71,7 @@ def _prepare_hf_weights(
return
hf_weights_files
,
use_safetensors
# Adapted from vllm/model_executor/weight_utils.py
# Adapted from vllm/model_executor/
model_loader/
weight_utils.py
def
_hf_tensorfile_iterator
(
filename
:
str
,
load_format
:
str
,
use_safetensors
:
bool
):
if
load_format
==
"npz"
:
...
...
examples/tensorize_vllm_model.py
View file @
69e1d2fb
...
...
@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
from
vllm.distributed
import
initialize_model_parallel
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
...
...
tests/lora/conftest.py
View file @
69e1d2fb
...
...
@@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup
()
get_model_old
=
get_model
def
get_model_patched
(
model_config
,
device_config
,
**
kwargs
):
return
get_model_old
(
model_config
,
device
_config
,
lora
_config
=
LoRAConfig
(
max_loras
=
4
,
max_lora_rank
=
8
)
)
def
get_model_patched
(
*
,
model_config
,
device_config
,
**
kwargs
):
kwargs
[
"lora_config"
]
=
LoRAConfig
(
max_loras
=
4
,
max_lora_rank
=
8
)
return
get_model_old
(
model_config
=
model
_config
,
device
_config
=
device_config
,
**
kwargs
)
with
patch
(
"vllm.worker.model_runner.get_model"
,
get_model_patched
):
engine
=
vllm
.
LLM
(
"meta-llama/Llama-2-7b-hf"
,
enable_lora
=
False
)
...
...
tests/lora/test_worker.py
View file @
69e1d2fb
...
...
@@ -3,8 +3,8 @@ import random
import
tempfile
from
unittest.mock
import
patch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
Lo
RA
Config
,
Model
Config
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
Lo
ad
Config
,
LoRA
Config
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.lora.models
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.worker.worker
import
Worker
...
...
@@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files):
"meta-llama/Llama-2-7b-hf"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
),
load_config
=
LoadConfig
(
download_dir
=
None
,
load_format
=
"dummy"
,
),
parallel_config
=
ParallelConfig
(
1
,
1
,
False
),
scheduler_config
=
SchedulerConfig
(
32
,
32
,
32
),
device_config
=
DeviceConfig
(
"cuda"
),
...
...
tests/model_executor/weight_utils.py
View file @
69e1d2fb
...
...
@@ -3,7 +3,7 @@ import os
import
huggingface_hub.constants
import
pytest
from
vllm.model_executor.weight_utils
import
enable_hf_transfer
from
vllm.model_executor.
model_loader.
weight_utils
import
enable_hf_transfer
def
test_hf_transfer_auto_activation
():
...
...
tests/quantization/test_autogptq_marlin_configs.py
View file @
69e1d2fb
...
...
@@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
@@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
tests/samplers/test_sampler.py
View file @
69e1d2fb
...
...
@@ -32,7 +32,12 @@ def _prepare_test(
1e-2
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
None
,
parallel_config
=
None
,
scheduler_config
=
None
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
return
input_tensor
,
fake_logits
,
sampler
,
model_runner
...
...
@@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device
=
input_tensor
.
device
,
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
None
,
parallel_config
=
None
,
scheduler_config
=
None
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
generation_model
=
GenerationMixin
()
generation_config
=
GenerationConfig
(
top_k
=
top_k
,
...
...
tests/spec_decode/utils.py
View file @
69e1d2fb
...
...
@@ -118,6 +118,7 @@ def create_worker(cls: type,
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
cache_config
=
engine_config
.
cache_config
,
load_config
=
engine_config
.
load_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
...
...
tests/tensorizer/__init__.py
→
tests/tensorizer
_loader
/__init__.py
View file @
69e1d2fb
File moved
tests/tensorizer/tensorize_vllm_model_for_testing.py
→
tests/tensorizer
_loader
/tensorize_vllm_model_for_testing.py
View file @
69e1d2fb
...
...
@@ -16,8 +16,8 @@ from transformers import AutoConfig, PretrainedConfig
from
vllm.distributed
import
initialize_model_parallel
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
...
...
@@ -74,7 +74,7 @@ def parse_args():
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it."
)
parser
=
EngineArgs
.
add_cli_args
(
parser
)
parser
=
TensorizerArgs
.
add_cli_args
(
EngineArgs
.
add_cli_args
(
parser
)
)
subparsers
=
parser
.
add_subparsers
(
dest
=
'command'
)
serialize_parser
=
subparsers
.
add_parser
(
...
...
tests/tensorizer/test_tensorizer.py
→
tests/tensorizer
_loader
/test_tensorizer.py
View file @
69e1d2fb
import
gc
import
json
import
os
import
subprocess
from
unittest.mock
import
MagicMock
,
patch
import
openai
import
pytest
import
ray
import
torch
from
tests.entrypoints.test_openai_server
import
ServerRunner
from
vllm
import
SamplingParams
from
vllm.config
import
TensorizerConfig
from
vllm.model_executor.tensorizer_loader
import
(
EncryptionParams
,
TensorSerializer
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
open_stream
)
from
vllm.model_executor.model_loader.tensorizer
import
(
EncryptionParams
,
TensorizerConfig
,
TensorSerializer
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
open_stream
)
prompts
=
[
"Hello, my name is"
,
...
...
@@ -22,6 +25,8 @@ prompts = [
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
)
model_ref
=
"facebook/opt-125m"
tensorize_model_for_testing_script
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"tensorize_vllm_model_for_testing.py"
)
def
is_curl_installed
():
...
...
@@ -38,7 +43,7 @@ def tensorizer_config():
return
config
@
patch
(
'vllm.model_executor.tensorizer
_loader
.TensorizerAgent'
)
@
patch
(
'vllm.model_executor.
model_loader.
tensorizer.TensorizerAgent'
)
def
test_load_with_tensorizer
(
mock_agent
,
tensorizer_config
):
mock_linear_method
=
MagicMock
()
mock_agent_instance
=
mock_agent
.
return_value
...
...
@@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
del
vllm_model
,
model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
loaded_vllm_model
=
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
tensorizer_uri
=
model_path
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
num_readers
=
1
,
vllm_tensorized
=
True
)
vllm_tensorized
=
True
),
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
...
...
@@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner):
model_ref
=
"EleutherAI/pythia-1.4b"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
loaded_hf_model
=
vllm_runner
(
model_ref
,
tensorizer_uri
=
tensorized_path
,
loaded_hf_model
=
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
tensorized_path
,
num_readers
=
1
,
vllm_tensorized
=
False
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
)
)
)
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
sampling_params
)
...
...
@@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
tensorizer_uri
=
model_path
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
,
num_readers
=
1
,
vllm_tensorized
=
True
)
vllm_tensorized
=
True
)
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
...
...
@@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
loaded_hf_model
=
vllm_runner
(
model_ref
,
tensorizer_uri
=
model_path
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
num_readers
=
1
,
vllm_tensorized
=
False
)
vllm_tensorized
=
False
)
)
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
...
...
@@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
torch
.
cuda
.
empty_cache
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
tensorizer_uri
=
model_path
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
num_readers
=
1
,
vllm_tensorized
=
True
,
),
enable_lora
=
True
,
max_loras
=
1
,
max_lora_rank
=
8
,
...
...
@@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def
test_load_without_tensorizer_load_format
(
vllm_runner
):
with
pytest
.
raises
(
ValueError
):
vllm_runner
(
model_ref
,
tensorizer_uri
=
"test"
)
vllm_runner
(
model_ref
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
,
vllm_tensorized
=
False
))
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_tensorize_vllm_model
(
tmp_path
):
# Test serialize command
serialize_args
=
[
"python3"
,
"
tensorize
r/tensorize_vllm
_model_for_testing
.py"
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"serialize"
,
"--serialized-directory"
,
tmp_path
,
"--suffix"
,
"tests"
"python3"
,
tensorize_model_for_testing
_script
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"serialize"
,
"--serialized-directory"
,
tmp_path
,
"--suffix"
,
"tests"
]
result
=
subprocess
.
run
(
serialize_args
,
capture_output
=
True
,
text
=
True
)
print
(
result
.
stdout
)
# Print the output of the serialize command
...
...
@@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path):
# Test deserialize command
deserialize_args
=
[
"python3"
,
"
tensorize
r/tensorize_vllm
_model_for_testing
.py"
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"deserialize"
,
"--path-to-tensors"
,
"python3"
,
tensorize_model_for_testing
_script
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"deserialize"
,
"--path-to-tensors"
,
path_to_tensors
]
result
=
subprocess
.
run
(
deserialize_args
,
capture_output
=
True
,
text
=
True
)
...
...
@@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path):
def
test_openai_apiserver_with_tensorizer
(
tmp_path
):
## Serialize model
serialize_args
=
[
"python3"
,
"
tensorize
r/tensorize_vllm
_model_for_testing
.py"
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"serialize"
,
"--serialized-directory"
,
tmp_path
,
"--suffix"
,
"tests"
"python3"
,
tensorize_model_for_testing
_script
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"serialize"
,
"--serialized-directory"
,
tmp_path
,
"--suffix"
,
"tests"
]
result
=
subprocess
.
run
(
serialize_args
,
capture_output
=
True
,
text
=
True
)
print
(
result
.
stdout
)
# Print the output of the serialize command
...
...
@@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
f
"
\n
{
result
.
stdout
}
\n
{
result
.
stderr
}
"
)
path_to_tensors
=
f
"
{
tmp_path
}
/vllm/
{
model_ref
}
/tests/model.tensors"
model_loader_extra_config
=
{
"tensorizer_uri"
:
path_to_tensors
,
"vllm_tensorized"
:
True
}
## Start OpenAI API server
openai_args
=
[
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"--load-format"
,
"tensorizer"
,
"--
tensorizer-uri"
,
path_to_tensors
,
"--vllm-tensorized
"
,
"--port"
,
"8000"
"tensorizer"
,
"--
model-loader-extra-config
"
,
json
.
dumps
(
model_loader_extra_config
),
"--port"
,
"8000"
]
server
=
ServerRunner
.
remote
(
openai_args
)
assert
ray
.
get
(
server
.
ready
.
remote
())
print
(
"Server ready."
)
assert
server
.
ready
.
remote
()
client
=
openai
.
OpenAI
(
base_url
=
"http://localhost:8000/v1"
,
api_key
=
"token-abc123"
,
)
completion
=
client
.
completions
.
create
(
model
=
model_ref
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
assert
completion
.
choices
[
0
].
text
is
not
None
and
len
(
completion
.
choices
[
0
].
text
)
>=
5
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
):
with
pytest
.
raises
(
ValueError
):
vllm_runner
(
model_ref
,
load_format
=
"safetensors"
,
tensorizer_uri
=
"test"
)
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
,
vllm_tensorized
=
False
))
def
test_tensorizer_with_tp
(
vllm_runner
):
...
...
@@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner):
vllm_runner
(
model_ref
,
tensorizer_uri
=
tensorized_path
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
tensorized_path
,
num_readers
=
1
,
vllm_tensorized
=
False
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
),
tensor_parallel_size
=
2
,
)
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_tensorizer_warn_quant
(
tmp_path
):
model_ref
=
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
serialize_args
=
[
"python3"
,
"tensorizer/tensorize_vllm_model_for_testing.py"
,
"--model"
,
model_ref
,
"--quantization"
,
"gptq"
,
"--tensorizer-uri"
,
"test"
,
"serialize"
,
"--serialized-directory"
,
tmp_path
,
"--suffix"
,
"tests"
]
result
=
subprocess
.
run
(
serialize_args
,
capture_output
=
True
,
text
=
True
)
assert
'PerformanceWarning'
in
result
.
stderr
tests/test_config.py
View file @
69e1d2fb
...
...
@@ -11,8 +11,6 @@ def test_get_sliding_window():
"Qwen/Qwen1.5-7B"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
@@ -30,8 +28,6 @@ def test_get_sliding_window():
"mistralai/Mistral-7B-v0.1"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
tests/test_logits_processor.py
View file @
69e1d2fb
...
...
@@ -37,7 +37,12 @@ def _prepare_test(
1e-2
,
dtype
=
input_tensor
.
dtype
)
logits_processor
=
MockLogitsProcessor
(
32000
,
0.5
,
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
None
,
parallel_config
=
None
,
scheduler_config
=
None
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
return
input_tensor
,
fake_logits
,
logits_processor
,
model_runner
...
...
tests/worker/test_model_runner.py
View file @
69e1d2fb
...
...
@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
None
,
None
,
scheduler_config
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
None
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
prompt_lens
=
[]
...
...
@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
model_config
,
None
,
scheduler_config
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
prompt_lens
=
[]
...
...
@@ -205,14 +212,17 @@ def test_empty_seq_group():
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
)
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
None
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
seq_group_metadata_list
=
[]
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
slot_mapping
=
(
...
...
@@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
...
...
@@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
100000
,
100000
,
enable_chunked_prefill
=
True
)
model_runner
=
ModelRunner
(
model_config
,
None
,
scheduler_config
,
None
,
None
,
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
,
is_driver_worker
=
True
)
model_runner
.
set_block_size
(
16
)
...
...
tests/worker/test_swap.py
View file @
69e1d2fb
...
...
@@ -23,6 +23,7 @@ def test_swap() -> None:
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
cache_config
=
engine_config
.
cache_config
,
load_config
=
engine_config
.
load_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
...
...
vllm/config.py
View file @
69e1d2fb
import
enum
import
io
import
json
import
os
import
typing
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
import
torch
...
...
@@ -18,10 +16,14 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.model_executor.
tensoriz
er
_
loader
import
TensorizerArgs
from
vllm.model_executor.
model_load
er
.
loader
import
BaseModelLoader
logger
=
init_logger
(
__name__
)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE
=
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
_GB
=
1
<<
30
...
...
@@ -35,18 +37,6 @@ class ModelConfig:
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
...
...
@@ -83,8 +73,6 @@ class ModelConfig:
tokenizer
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
download_dir
:
Optional
[
str
],
load_format
:
str
,
dtype
:
Union
[
str
,
torch
.
dtype
],
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
...
...
@@ -101,8 +89,6 @@ class ModelConfig:
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
self
.
trust_remote_code
=
trust_remote_code
self
.
download_dir
=
download_dir
self
.
load_format
=
load_format
self
.
seed
=
seed
self
.
revision
=
revision
self
.
code_revision
=
code_revision
...
...
@@ -113,64 +99,16 @@ class ModelConfig:
self
.
max_context_len_to_capture
=
max_context_len_to_capture
self
.
max_logprobs
=
max_logprobs
if
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
download_dir
,
revision
=
revision
)
else
:
model_path
=
model
self
.
model
=
model_path
self
.
download_dir
=
model_path
self
.
tokenizer
=
model_path
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
max_model_len
)
self
.
_verify_load_format
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
def
_verify_load_format
(
self
)
->
None
:
load_format
=
self
.
load_format
.
lower
()
supported_load_format
=
[
"auto"
,
"pt"
,
"safetensors"
,
"npcache"
,
"dummy"
,
"tensorizer"
]
rocm_not_supported_load_format
:
List
[
str
]
=
[]
if
load_format
not
in
supported_load_format
:
raise
ValueError
(
f
"Unknown load format:
{
self
.
load_format
}
. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
"'dummy'."
)
if
is_hip
()
and
load_format
in
rocm_not_supported_load_format
:
rocm_supported_load_format
=
[
f
for
f
in
supported_load_format
if
(
f
not
in
rocm_not_supported_load_format
)
]
raise
ValueError
(
f
"load format '
{
load_format
}
' is not supported in ROCm. "
f
"Supported load format are "
f
"
{
rocm_supported_load_format
}
"
)
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
# architectures can be None instead of []
if
architectures
and
"MixtralForCausalLM"
in
architectures
\
and
load_format
==
"pt"
:
raise
ValueError
(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. "
)
self
.
load_format
=
load_format
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
]:
...
...
@@ -471,6 +409,65 @@ class TokenizerPoolConfig:
return
tokenizer_pool_config
class
LoadFormat
(
str
,
enum
.
Enum
):
AUTO
=
"auto"
PT
=
"pt"
SAFETENSORS
=
"safetensors"
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
@
dataclass
class
LoadConfig
:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
download_dir
:
Optional
[
str
]
=
None
model_loader_extra_config
:
Optional
[
Union
[
str
,
dict
]]
=
field
(
default_factory
=
dict
)
def
__post_init__
(
self
):
model_loader_extra_config
=
self
.
model_loader_extra_config
or
{}
if
isinstance
(
model_loader_extra_config
,
str
):
self
.
model_loader_extra_config
=
json
.
loads
(
model_loader_extra_config
)
self
.
_verify_load_format
()
def
_verify_load_format
(
self
)
->
None
:
if
not
isinstance
(
self
.
load_format
,
str
):
return
load_format
=
self
.
load_format
.
lower
()
self
.
load_format
=
LoadFormat
(
load_format
)
rocm_not_supported_load_format
:
List
[
str
]
=
[]
if
is_hip
()
and
load_format
in
rocm_not_supported_load_format
:
rocm_supported_load_format
=
[
f
for
f
in
LoadFormat
.
__members__
if
(
f
not
in
rocm_not_supported_load_format
)
]
raise
ValueError
(
f
"load format '
{
load_format
}
' is not supported in ROCm. "
f
"Supported load formats are "
f
"
{
rocm_supported_load_format
}
"
)
class
ParallelConfig
:
"""Configuration for the distributed execution.
...
...
@@ -699,8 +696,6 @@ class SpeculativeConfig:
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
download_dir
=
target_model_config
.
download_dir
,
load_format
=
target_model_config
.
load_format
,
dtype
=
target_model_config
.
dtype
,
seed
=
target_model_config
.
seed
,
revision
=
draft_revision
,
...
...
@@ -887,65 +882,6 @@ class VisionLanguageConfig:
f
"
{
[
x
.
name
for
x
in
cls
.
ImageInputType
]
}
."
)
from
e
@
dataclass
class
TensorizerConfig
:
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
s3_endpoint
:
Optional
[
str
]
=
None
model_class
:
Optional
[
torch
.
nn
.
Module
]
=
None
hf_config
:
Optional
[
PretrainedConfig
]
=
None
dtype
:
Union
[
str
,
torch
.
dtype
]
=
None
def
_construct_tensorizer_args
(
self
)
->
"TensorizerArgs"
:
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
tensorizer_args
=
{
"tensorizer_uri"
:
self
.
tensorizer_uri
,
"vllm_tensorized"
:
self
.
vllm_tensorized
,
"verify_hash"
:
self
.
verify_hash
,
"num_readers"
:
self
.
num_readers
,
"encryption_keyfile"
:
self
.
encryption_keyfile
,
"s3_access_key_id"
:
self
.
s3_access_key_id
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_endpoint"
:
self
.
s3_endpoint
,
}
return
TensorizerArgs
(
**
tensorizer_args
)
def
verify_with_parallel_config
(
self
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
if
(
parallel_config
.
tensor_parallel_size
>
1
and
self
.
tensorizer_uri
is
not
None
):
raise
ValueError
(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`."
)
def
verify_with_model_config
(
self
,
model_config
)
->
None
:
if
(
model_config
.
quantization
is
not
None
and
self
.
tensorizer_uri
is
not
None
):
from
vllm.model_executor.tensorizer_loader
import
(
tensorizer_warning
)
tensorizer_warning
(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
if
(
model_config
.
load_format
!=
"tensorizer"
and
self
.
tensorizer_uri
is
not
None
):
raise
ValueError
(
"A tensorizer uri was passed for tensorizer loading, but the "
f
"load format was set to
{
model_config
.
load_format
}
. "
"Please set the load format to 'tensorizer' to use "
f
"tensorizer args."
)
_STR_DTYPE_TO_TORCH_DTYPE
=
{
"half"
:
torch
.
float16
,
"float16"
:
torch
.
float16
,
...
...
@@ -1105,11 +1041,11 @@ class EngineConfig:
parallel_config
:
ParallelConfig
scheduler_config
:
SchedulerConfig
device_config
:
DeviceConfig
load_config
:
LoadConfig
lora_config
:
Optional
[
LoRAConfig
]
vision_language_config
:
Optional
[
VisionLanguageConfig
]
speculative_config
:
Optional
[
SpeculativeConfig
]
decoding_config
:
Optional
[
DecodingConfig
]
tensorizer_config
:
Optional
[
TensorizerConfig
]
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
...
...
@@ -1117,11 +1053,6 @@ class EngineConfig:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
tensorizer_config
:
self
.
tensorizer_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
tensorizer_config
.
verify_with_model_config
(
self
.
model_config
)
if
self
.
lora_config
:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_scheduler_config
(
...
...
vllm/engine/arg_utils.py
View file @
69e1d2fb
import
argparse
import
dataclasses
import
io
import
os
from
dataclasses
import
dataclass
from
typing
import
BinaryIO
,
Optional
,
Union
from
typing
import
Optional
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
Lo
RA
Config
,
Model
Config
,
Parall
elConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
EngineConfig
,
Lo
ad
Config
,
LoRA
Config
,
Mod
elConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
from
vllm.utils
import
str_to_int_tuple
...
...
@@ -60,17 +57,7 @@ class EngineArgs:
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
# Tensorizer configuration parameters
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
=
None
vllm_tensorized
:
bool
=
False
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
s3_endpoint
:
Optional
[
str
]
=
None
model_loader_extra_config
:
Optional
[
dict
]
=
None
# Related to Vision-language models such as llava
image_input_type
:
Optional
[
str
]
=
None
...
...
@@ -429,7 +416,16 @@ class EngineArgs:
default
=
None
,
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding'
)
parser
=
TensorizerArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--model-loader-extra-config'
,
type
=
str
,
default
=
EngineArgs
.
model_loader_extra_config
,
help
=
'Extra config for model loader. '
'This will be passed to the model loader '
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.'
)
return
parser
@
classmethod
...
...
@@ -444,11 +440,11 @@ class EngineArgs:
device_config
=
DeviceConfig
(
self
.
device
)
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
trust_remote_code
,
self
.
d
ownload_dir
,
self
.
load_format
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
code_revisio
n
,
self
.
tokenizer_revision
,
self
.
max_model_le
n
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
)
self
.
trust_remote_code
,
self
.
d
type
,
self
.
seed
,
self
.
revision
,
self
.
code_revision
,
self
.
tokenizer_
revision
,
self
.
max_model_le
n
,
self
.
quantizatio
n
,
self
.
quantization
_param_path
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
)
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
...
@@ -492,15 +488,10 @@ class EngineArgs:
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
tensorizer_config
=
TensorizerConfig
(
tensorizer_uri
=
self
.
tensorizer_uri
,
vllm_tensorized
=
self
.
vllm_tensorized
,
verify_hash
=
self
.
verify_hash
,
num_readers
=
self
.
num_readers
,
encryption_keyfile
=
self
.
encryption_keyfile
,
s3_access_key_id
=
self
.
s3_access_key_id
,
s3_secret_access_key
=
self
.
s3_secret_access_key
,
s3_endpoint
=
self
.
s3_endpoint
,
load_config
=
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
)
if
self
.
image_input_type
:
...
...
@@ -530,8 +521,8 @@ class EngineArgs:
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
,
decoding_config
=
decoding
_config
,
tensorizer_config
=
tensorizer
_config
)
load_config
=
load
_config
,
decoding_config
=
decoding
_config
)
@
dataclass
...
...
vllm/engine/llm_engine.py
View file @
69e1d2fb
...
...
@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from
transformers
import
PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
Lo
RA
Config
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
S
peculativeConfig
,
Tensorizer
Config
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
Lo
ad
Config
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
S
chedulerConfig
,
Speculative
Config
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
...
...
@@ -72,11 +72,11 @@ class LLMEngine:
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
tensorizer_config
:
Optional
[
TensorizerConfig
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
...
...
@@ -92,8 +92,8 @@ class LLMEngine:
f
"trust_remote_code=
{
model_config
.
trust_remote_code
}
, "
f
"dtype=
{
model_config
.
dtype
}
, "
f
"max_seq_len=
{
model_config
.
max_model_len
}
, "
f
"download_dir=
{
model
_config
.
download_dir
!
r
}
, "
f
"load_format=
{
model
_config
.
load_format
}
, "
f
"download_dir=
{
load
_config
.
download_dir
!
r
}
, "
f
"load_format=
{
load
_config
.
load_format
}
, "
f
"tensor_parallel_size=
{
parallel_config
.
tensor_parallel_size
}
, "
f
"disable_custom_all_reduce="
f
"
{
parallel_config
.
disable_custom_all_reduce
}
, "
...
...
@@ -114,8 +114,8 @@ class LLMEngine:
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
load_config
=
load_config
self
.
decoding_config
=
decoding_config
or
DecodingConfig
()
self
.
tensorizer_config
=
tensorizer_config
self
.
log_stats
=
log_stats
self
.
_init_tokenizer
()
...
...
@@ -131,7 +131,7 @@ class LLMEngine:
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
,
tensorizer_config
=
tensorizer
_config
,
load_config
=
load
_config
,
)
self
.
_initialize_kv_caches
()
...
...
@@ -271,9 +271,6 @@ class LLMEngine:
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
tensorizer_config
:
self
.
tensorizer_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
lora_config
:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_scheduler_config
(
...
...
vllm/executor/cpu_executor.py
View file @
69e1d2fb
...
...
@@ -40,6 +40,7 @@ class CPUExecutor(ExecutorBase):
scheduler_config
=
self
.
scheduler_config
,
device_config
=
self
.
device_config
,
cache_config
=
self
.
cache_config
,
load_config
=
self
.
load_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment