Unverified commit 73202dbe, authored by bnellnm, committed by GitHub
Browse files

[Kernel][Misc] register ops to prevent graph breaks (#6917)


Co-authored-by: Sage Moore <sage@neuralmagic.com>
parent 7015417f
......@@ -75,6 +75,10 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
# For now, bump up cache limits for recompilations during CUDA graph warmups.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
......@@ -1060,9 +1064,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
self.model = torch.compile(self.model,
fullgraph=True,
backend="eager")
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager")
def save_sharded_state(
self,
......
......@@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment