Unverified Commit 3682e33f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[v1] fix compilation cache (#11598)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 0aa38d16
...@@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are ...@@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed. initialized randomly with a fixed seed.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Any, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -54,6 +54,16 @@ class LlamaConfig: ...@@ -54,6 +54,16 @@ class LlamaConfig:
tractable_init: bool = False tractable_init: bool = False
random_seed: int = 0 random_seed: int = 0
def compute_hash(self) -> str:
factors: List[Any] = []
for k, v in self.__dict__.items():
if k == "random_seed":
continue
factors.append((k, v))
factors.sort()
import hashlib
return hashlib.md5(str(factors).encode()).hexdigest()
def __post_init__(self): def __post_init__(self):
assert self.mlp_size >= self.hidden_size assert self.mlp_size >= self.hidden_size
...@@ -263,7 +273,8 @@ def run_model(llama_config, ...@@ -263,7 +273,8 @@ def run_model(llama_config,
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ) level=CompilationLevel.NO_COMPILATION, )
vllm_config = VllmConfig(compilation_config=compilation_config) vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config, model = LlamaModel(config=llama_config,
vllm_config=vllm_config, vllm_config=vllm_config,
......
...@@ -619,8 +619,10 @@ class PiecewiseBackend: ...@@ -619,8 +619,10 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either # the entries for different shapes that we need to either
# compile or capture cudagraph # compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
self.capture_sizes) # to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.capture_sizes): for shape in self.compile_sizes.union(self.capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry( self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape, runtime_shape=shape,
...@@ -628,12 +630,17 @@ class PiecewiseBackend: ...@@ -628,12 +630,17 @@ class PiecewiseBackend:
use_cudagraph=shape in self.capture_sizes, use_cudagraph=shape in self.capture_sizes,
) )
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any: def __call__(self, *args) -> Any:
if not self.first_run_finished: if not self.first_run_finished:
self.first_run_finished = True self.first_run_finished = True
# no specific sizes to compile self.check_for_ending_compilation()
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.vllm_config)
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]] runtime_shape = args[self.sym_shape_indices[0]]
...@@ -662,10 +669,7 @@ class PiecewiseBackend: ...@@ -662,10 +669,7 @@ class PiecewiseBackend:
# finished compilations for all required shapes # finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
if not entry.use_cudagraph: if not entry.use_cudagraph:
return entry.runnable(*args) return entry.runnable(*args)
......
...@@ -9,8 +9,8 @@ from contextlib import contextmanager ...@@ -9,8 +9,8 @@ from contextlib import contextmanager
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Final, List, Literal, Mapping, Optional, Protocol, Set,
Union) Tuple, Type, Union)
import torch import torch
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
...@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], ...@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]] PretrainedConfig]]
class SupportsHash(Protocol):
def compute_hash(self) -> str:
...
class ModelConfig: class ModelConfig:
"""Configuration for the model. """Configuration for the model.
...@@ -2969,6 +2975,10 @@ class VllmConfig: ...@@ -2969,6 +2975,10 @@ class VllmConfig:
init=True) # type: ignore init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None, kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore init=True) # type: ignore
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
instance_id: str = "" instance_id: str = ""
def compute_hash(self) -> str: def compute_hash(self) -> str:
...@@ -3000,33 +3010,62 @@ class VllmConfig: ...@@ -3000,33 +3010,62 @@ class VllmConfig:
vllm_factors.append(__version__) vllm_factors.append(__version__)
if self.model_config: if self.model_config:
vllm_factors.append(self.model_config.compute_hash()) vllm_factors.append(self.model_config.compute_hash())
else:
vllm_factors.append("None")
if self.cache_config: if self.cache_config:
vllm_factors.append(self.cache_config.compute_hash()) vllm_factors.append(self.cache_config.compute_hash())
else:
vllm_factors.append("None")
if self.parallel_config: if self.parallel_config:
vllm_factors.append(self.parallel_config.compute_hash()) vllm_factors.append(self.parallel_config.compute_hash())
else:
vllm_factors.append("None")
if self.scheduler_config: if self.scheduler_config:
vllm_factors.append(self.scheduler_config.compute_hash()) vllm_factors.append(self.scheduler_config.compute_hash())
else:
vllm_factors.append("None")
if self.device_config: if self.device_config:
vllm_factors.append(self.device_config.compute_hash()) vllm_factors.append(self.device_config.compute_hash())
else:
vllm_factors.append("None")
if self.load_config: if self.load_config:
vllm_factors.append(self.load_config.compute_hash()) vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config: if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash()) vllm_factors.append(self.lora_config.compute_hash())
else:
vllm_factors.append("None")
if self.speculative_config: if self.speculative_config:
vllm_factors.append(self.speculative_config.compute_hash()) vllm_factors.append(self.speculative_config.compute_hash())
else:
vllm_factors.append("None")
if self.decoding_config: if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash()) vllm_factors.append(self.decoding_config.compute_hash())
else:
vllm_factors.append("None")
if self.observability_config: if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config: if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash()) vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config: if self.quant_config:
pass # should be captured by model_config.quantization pass # should be captured by model_config.quantization
if self.compilation_config: if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash()) vllm_factors.append(self.compilation_config.compute_hash())
else:
vllm_factors.append("None")
if self.kv_transfer_config: if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash()) vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
vllm_factors.append(self.additional_config.compute_hash())
else:
vllm_factors.append("None")
factors.append(vllm_factors) factors.append(vllm_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
......
...@@ -48,6 +48,7 @@ class Worker: ...@@ -48,6 +48,7 @@ class Worker:
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment