Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ee535b2
Unverified
Commit
0ee535b2
authored
May 09, 2024
by
Woosuk Kwon
Committed by
GitHub
May 09, 2024
Browse files
[Misc] Set block size at initialization & Fix test_model_runner (#4705)
parent
190bc838
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
64 additions
and
104 deletions
+64
-104
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+32
-58
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+9
-12
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+1
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+21
-33
vllm/worker/worker.py
vllm/worker/worker.py
+1
-1
No files found.
tests/worker/test_model_runner.py
View file @
0ee535b2
import
pytest
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
ModelRunner
:
engine_args
=
EngineArgs
(
model
,
*
args
,
**
kwargs
)
engine_config
=
engine_args
.
create_engine_config
()
model_runner
=
ModelRunner
(
model_config
=
engine_config
.
model_config
,
parallel_config
=
engine_config
.
parallel_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
cache_config
=
engine_config
.
cache_config
,
load_config
=
engine_config
.
load_config
,
lora_config
=
engine_config
.
lora_config
,
is_driver_worker
=
True
,
)
return
model_runner
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_prompt
(
batch_size
):
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
model_config
=
None
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
model_runner
=
_create_model_runner
(
"facebook/opt-125m"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
)
seq_lens
=
[]
seq_group_metadata_list
=
[]
...
...
@@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size):
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_decode_cuda_graph
(
batch_size
):
model_
config
=
ModelConfig
(
model_
runner
=
_create_model_runner
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
)
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
seq_lens
=
[]
seq_group_metadata_list
=
[]
...
...
@@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):
def
test_empty_seq_group
():
"""Verify prepare prompt and decode returns empty output."""
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
model_runner
=
_create_model_runner
(
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
)
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
None
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
seq_group_metadata_list
=
[]
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
slot_mapping
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
...
...
@@ -260,29 +248,15 @@ def distributed_init():
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
distributed_init
):
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
model_runner
=
_create_model_runner
(
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
enforce_eager
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
True
,
)
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
True
)
model_runner
=
ModelRunner
(
model_config
=
model_config
,
parallel_config
=
None
,
scheduler_config
=
scheduler_config
,
device_config
=
None
,
load_config
=
None
,
lora_config
=
None
,
is_driver_worker
=
True
)
model_runner
.
set_block_size
(
16
)
# Add prefill requests.
seq_lens
=
[]
...
...
vllm/worker/cpu_model_runner.py
View file @
0ee535b2
...
...
@@ -4,8 +4,9 @@ import torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
...
...
@@ -26,6 +27,7 @@ class CPUModelRunner:
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
...
...
@@ -39,27 +41,22 @@ class CPUModelRunner:
self
.
scheduler_config
=
scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self
.
sliding_window
=
(
model_config
.
get_sliding_window
()
if
model_config
is
not
None
else
None
)
self
.
device_config
=
(
device_config
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
Non
e
)
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtyp
e
)
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
block_size
:
int
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
...
...
vllm/worker/cpu_worker.py
View file @
0ee535b2
...
...
@@ -151,6 +151,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
...
...
vllm/worker/model_runner.py
View file @
0ee535b2
...
...
@@ -9,8 +9,9 @@ import torch.nn as nn
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
get_attn_backend
)
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
from
vllm.distributed.device_communicators
import
(
custom_all_reduce
,
pynccl_utils
)
...
...
@@ -106,6 +107,7 @@ class ModelRunner:
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
...
...
@@ -115,48 +117,40 @@ class ModelRunner:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
vision_language_config
=
vision_language_config
# model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self
.
sliding_window
=
(
model_config
.
get_sliding_window
()
if
model_config
is
not
None
else
None
)
self
.
device_config
=
(
device_config
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
pin_memory
=
is_pin_memory_available
()
# Set after load_model.
self
.
lora_manager
:
LRUCacheWorkerLoRAManager
=
None
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
graph_runners
:
Dict
[
int
,
CUDAGraphRunner
]
=
{}
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_seq_len_to_capture
=
(
self
.
model_config
.
max_seq_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vision_language_config
=
vision_language_config
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
# Lazy initialization
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
:
torch
.
Tensor
# Set after initial profiling.
self
.
graph_block_tables
=
np
.
zeros
(
(
max
(
_BATCH_SIZES_TO_CAPTURE
),
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
)
# Lazy initialization
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
# Set if the backend is flashinfer.
self
.
flashinfer_workspace_buffer
:
torch
.
Tensor
# Set after load_model.
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
...
...
@@ -211,13 +205,6 @@ class ModelRunner:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
self
.
block_size
=
block_size
self
.
graph_block_tables
=
np
.
zeros
(
(
max
(
_BATCH_SIZES_TO_CAPTURE
),
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
...
...
@@ -835,6 +822,7 @@ class ModelRunner:
dummy_lora_requests
=
[]
dummy_lora_requests_per_seq
=
[]
if
self
.
lora_config
:
assert
self
.
lora_manager
is
not
None
with
self
.
lora_manager
.
dummy_lora_cache
():
for
idx
in
range
(
self
.
lora_config
.
max_loras
):
lora_id
=
idx
+
1
...
...
vllm/worker/worker.py
View file @
0ee535b2
...
...
@@ -75,6 +75,7 @@ class Worker(WorkerBase):
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
=
load_config
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
...
...
@@ -184,7 +185,6 @@ class Worker(WorkerBase):
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
gpu_cache
=
self
.
cache_engine
.
gpu_cache
self
.
model_runner
.
set_block_size
(
self
.
cache_engine
.
block_size
)
def
_warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment