Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3682e33f
Unverified
Commit
3682e33f
authored
Dec 30, 2024
by
youkaichao
Committed by
GitHub
Dec 30, 2024
Browse files
[v1] fix compilation cache (#11598)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
0aa38d16
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
14 deletions
+69
-14
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+13
-2
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-9
vllm/config.py
vllm/config.py
+42
-3
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-0
No files found.
tests/compile/piecewise/test_toy_llama.py
View file @
3682e33f
...
...
@@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
from
typing
import
Any
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -54,6 +54,16 @@ class LlamaConfig:
tractable_init
:
bool
=
False
random_seed
:
int
=
0
def
compute_hash
(
self
)
->
str
:
factors
:
List
[
Any
]
=
[]
for
k
,
v
in
self
.
__dict__
.
items
():
if
k
==
"random_seed"
:
continue
factors
.
append
((
k
,
v
))
factors
.
sort
()
import
hashlib
return
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
def
__post_init__
(
self
):
assert
self
.
mlp_size
>=
self
.
hidden_size
...
...
@@ -263,7 +273,8 @@ def run_model(llama_config,
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
,
additional_config
=
llama_config
)
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
...
...
vllm/compilation/backends.py
View file @
3682e33f
...
...
@@ -619,8 +619,10 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either
# compile or capture cudagraph
self
.
concrete_size_entries
:
Dict
[
int
,
ConcreteSizeEntry
]
=
{}
self
.
to_be_compiled_sizes
:
Set
[
int
]
=
self
.
compile_sizes
.
union
(
self
.
capture_sizes
)
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self
.
to_be_compiled_sizes
:
Set
[
int
]
=
self
.
compile_sizes
.
copy
()
for
shape
in
self
.
compile_sizes
.
union
(
self
.
capture_sizes
):
self
.
concrete_size_entries
[
shape
]
=
ConcreteSizeEntry
(
runtime_shape
=
shape
,
...
...
@@ -628,12 +630,17 @@ class PiecewiseBackend:
use_cudagraph
=
shape
in
self
.
capture_sizes
,
)
def
check_for_ending_compilation
(
self
):
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self
.
compilation_config
.
inductor_hash_cache
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
def
__call__
(
self
,
*
args
)
->
Any
:
if
not
self
.
first_run_finished
:
self
.
first_run_finished
=
True
# no specific sizes to compile
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
end_monitoring_torch_compile
(
self
.
vllm_config
)
self
.
check_for_ending_compilation
()
return
self
.
compiled_graph_for_general_shape
(
*
args
)
runtime_shape
=
args
[
self
.
sym_shape_indices
[
0
]]
...
...
@@ -662,10 +669,7 @@ class PiecewiseBackend:
# finished compilations for all required shapes
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
# save the hash of the inductor graph for the next run
self
.
compilation_config
.
inductor_hash_cache
.
save_to_file
()
end_monitoring_torch_compile
(
self
.
vllm_config
)
self
.
check_for_ending_compilation
()
if
not
entry
.
use_cudagraph
:
return
entry
.
runnable
(
*
args
)
...
...
vllm/config.py
View file @
3682e33f
...
...
@@ -9,8 +9,8 @@ from contextlib import contextmanager
from
dataclasses
import
dataclass
,
field
,
replace
from
pathlib
import
Path
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Counter
,
Dict
,
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
...
...
@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig
]]
class
SupportsHash
(
Protocol
):
def
compute_hash
(
self
)
->
str
:
...
class
ModelConfig
:
"""Configuration for the model.
...
...
@@ -2969,6 +2975,10 @@ class VllmConfig:
init
=
True
)
# type: ignore
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging.
additional_config
:
SupportsHash
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
instance_id
:
str
=
""
def
compute_hash
(
self
)
->
str
:
...
...
@@ -3000,33 +3010,62 @@ class VllmConfig:
vllm_factors
.
append
(
__version__
)
if
self
.
model_config
:
vllm_factors
.
append
(
self
.
model_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
cache_config
:
vllm_factors
.
append
(
self
.
cache_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
parallel_config
:
vllm_factors
.
append
(
self
.
parallel_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
scheduler_config
:
vllm_factors
.
append
(
self
.
scheduler_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
device_config
:
vllm_factors
.
append
(
self
.
device_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
load_config
:
vllm_factors
.
append
(
self
.
load_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
lora_config
:
vllm_factors
.
append
(
self
.
lora_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
speculative_config
:
vllm_factors
.
append
(
self
.
speculative_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
decoding_config
:
vllm_factors
.
append
(
self
.
decoding_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
observability_config
:
vllm_factors
.
append
(
self
.
observability_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
prompt_adapter_config
:
vllm_factors
.
append
(
self
.
prompt_adapter_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
quant_config
:
pass
# should be captured by model_config.quantization
if
self
.
compilation_config
:
vllm_factors
.
append
(
self
.
compilation_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
kv_transfer_config
:
vllm_factors
.
append
(
self
.
kv_transfer_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
if
self
.
additional_config
:
vllm_factors
.
append
(
self
.
additional_config
.
compute_hash
())
else
:
vllm_factors
.
append
(
"None"
)
factors
.
append
(
vllm_factors
)
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()[:
10
]
...
...
vllm/v1/worker/gpu_worker.py
View file @
3682e33f
...
...
@@ -48,6 +48,7 @@ class Worker:
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
parallel_config
.
rank
=
rank
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment