Unverified commit 16d2ad1d, authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.empty_cache` with `torch.accelerator.empty_cache` (#30681)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5dc35387
...@@ -530,7 +530,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ...@@ -530,7 +530,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
assert positive_values > 0 assert positive_values > 0
finally: finally:
del llm del llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -1065,7 +1065,7 @@ def test_spec_decode_logprobs( ...@@ -1065,7 +1065,7 @@ def test_spec_decode_logprobs(
for logprobs in output.logprobs: for logprobs in output.logprobs:
ref_logprobs.extend(logprobs.values()) ref_logprobs.extend(logprobs.values())
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Run spec decode LLM. # Run spec decode LLM.
...@@ -1095,7 +1095,7 @@ def test_spec_decode_logprobs( ...@@ -1095,7 +1095,7 @@ def test_spec_decode_logprobs(
for logprobs in output.logprobs: for logprobs in output.logprobs:
spec_logprobs.extend(logprobs.values()) spec_logprobs.extend(logprobs.values())
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Per-token logprobs are expected to be the same. # Per-token logprobs are expected to be the same.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
# Regular expressions for the `torch.cuda` calls this checker reports.
# Only `torch.cuda.empty_cache` is listed at the moment; the word-boundary
# anchors (`\b`) stop the pattern from matching inside a longer identifier.
_TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.empty_cache\b",
]
# Path prefixes exempt from the check: `main` skips every command-line
# filename that starts with one of these strings.
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
def scan_file(path: str, patterns=None) -> int:
    """Scan one file for disallowed ``torch.cuda`` API calls.

    Every occurrence is reported on stdout as ``<path>:<line>: error: ...``,
    ordered by line number, so a single run lists all violations in the file
    instead of stopping at the first one.

    Args:
        path: Path of the UTF-8 encoded text file to scan.
        patterns: Optional iterable of regular-expression strings to search
            for. When omitted, the module-level ``_TORCH_CUDA_PATTERNS`` list
            is used.

    Returns:
        1 if at least one pattern matched, otherwise 0.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if patterns is None:
        patterns = _TORCH_CUDA_PATTERNS

    with open(path, encoding="utf-8") as f:
        content = f.read()

    line_numbers = []
    for pattern in patterns:
        for match in re.finditer(pattern, content, re.MULTILINE):
            # The line number is one more than the number of newlines that
            # precede the match. Counting strictly before match.start() keeps
            # the result right even when the match itself begins on a newline,
            # and str.count with bounds avoids copying a slice of the file.
            line_numbers.append(content.count("\n", 0, match.start()) + 1)

    for line_num in sorted(line_numbers):
        print(
            f"{path}:{line_num}: "
            "\033[91merror:\033[0m "  # red color
            "Found torch.cuda API call"
        )

    return 1 if line_numbers else 0
def main():
    """Check every file named on the command line and return an exit status.

    Filenames that start with one of the ``ALLOWED_FILES`` prefixes are
    skipped. The remaining files are passed to ``scan_file`` and the results
    are OR-ed together, so the status is non-zero as soon as any file reports
    a violation.
    """
    status = 0
    for candidate in sys.argv[1:]:
        # str.startswith accepts a tuple, which tests all prefixes in one call.
        exempt = candidate.startswith(tuple(ALLOWED_FILES))
        if not exempt:
            status |= scan_file(candidate)
    return status
# Script entry point: use the value returned by main() as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
...@@ -260,7 +260,9 @@ class CUDAGraphWrapper: ...@@ -260,7 +260,9 @@ class CUDAGraphWrapper:
# therefore, we only run gc for the first graph, # therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs. # and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None)) stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None)) stack.enter_context(
patch("torch.accelerator.empty_cache", lambda: None)
)
if self.graph_pool is not None: if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool) set_graph_pool_id(self.graph_pool)
......
...@@ -408,7 +408,7 @@ class ElasticEPScalingExecutor: ...@@ -408,7 +408,7 @@ class ElasticEPScalingExecutor:
gc.collect() gc.collect()
torch.cuda.synchronize() torch.cuda.synchronize()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
unlock_workspace() unlock_workspace()
self.worker.compile_or_warm_up_model() self.worker.compile_or_warm_up_model()
lock_workspace() lock_workspace()
......
...@@ -1916,14 +1916,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ...@@ -1916,14 +1916,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
gc.collect() gc.collect()
from vllm.platforms import current_platform from vllm.platforms import current_platform
empty_cache = current_platform.empty_cache if not current_platform.is_cpu():
if empty_cache is not None: torch.accelerator.empty_cache()
empty_cache() try:
try:
if not current_platform.is_cpu():
torch._C._host_emptyCache() torch._C._host_emptyCache()
except AttributeError: except AttributeError:
logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5") logger.warning(
"torch._C._host_emptyCache() only available in Pytorch >=2.5"
)
def in_the_same_node_as( def in_the_same_node_as(
......
...@@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
): ):
num_pad = 256 // weight.element_size() num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return weight return weight
......
...@@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# secondly, process mxfp weights # secondly, process mxfp weights
if self.emulate: if self.emulate:
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return return
from aiter.utility.fp4_utils import e8m0_shuffle from aiter.utility.fp4_utils import e8m0_shuffle
...@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True layer.w2_weight.is_shuffled = True
torch.cuda.empty_cache() torch.accelerator.empty_cache()
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
...@@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): ...@@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
del layer.w2_weight del layer.w2_weight
layer.w13_weight = None layer.w13_weight = None
layer.w2_weight = None layer.w2_weight = None
torch.cuda.empty_cache() torch.accelerator.empty_cache()
if self.static_input_scales: if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None: if layer.w13_input_scale is None or layer.w2_input_scale is None:
......
...@@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: ...@@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
import torch.nn.functional as F import torch.nn.functional as F
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return weight return weight
......
...@@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
**stacked_quant_state_dict, **stacked_quant_state_dict,
} }
self._bind_quant_states_to_params(model, stacked_quant_state_dict) self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache() torch.accelerator.empty_cache()
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
...@@ -96,7 +96,7 @@ class MemorySnapshot: ...@@ -96,7 +96,7 @@ class MemorySnapshot:
# rather than `torch.cuda.memory_reserved()` . # rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`, # After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink # `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens. # when we call `torch.accelerator.empty_cache()` or OOM happens.
self.torch_peak = current_platform.memory_stats(device).get( self.torch_peak = current_platform.memory_stats(device).get(
"allocated_bytes.all.peak", 0 "allocated_bytes.all.peak", 0
) )
...@@ -250,7 +250,7 @@ def memory_profiling( ...@@ -250,7 +250,7 @@ def memory_profiling(
until after profiling to get (c.). until after profiling to get (c.).
""" """
gc.collect() gc.collect()
current_platform.empty_cache() torch.accelerator.empty_cache()
current_platform.reset_peak_memory_stats(baseline_snapshot.device_) current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
result = MemoryProfilingResult( result = MemoryProfilingResult(
...@@ -264,7 +264,7 @@ def memory_profiling( ...@@ -264,7 +264,7 @@ def memory_profiling(
yield result yield result
gc.collect() gc.collect()
current_platform.empty_cache() torch.accelerator.empty_cache()
result.after_profile.measure() result.after_profile.measure()
......
...@@ -1036,4 +1036,4 @@ def apply_top_k_top_p_triton( ...@@ -1036,4 +1036,4 @@ def apply_top_k_top_p_triton(
def reset_buffer_cache(): def reset_buffer_cache():
_TRITON_BUFFER_CACHE.clear() _TRITON_BUFFER_CACHE.clear()
_TRITON_TABLE_CACHE.clear() _TRITON_TABLE_CACHE.clear()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
...@@ -496,7 +496,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -496,7 +496,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_time = time.perf_counter() start_time = time.perf_counter()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config): with self.maybe_setup_dummy_loras(self.lora_config):
......
...@@ -278,7 +278,7 @@ class Worker(WorkerBase): ...@@ -278,7 +278,7 @@ class Worker(WorkerBase):
# Now take memory snapshot after NCCL is initialized # Now take memory snapshot after NCCL is initialized
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# take current memory snapshot # take current memory snapshot
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
...@@ -585,7 +585,7 @@ class Worker(WorkerBase): ...@@ -585,7 +585,7 @@ class Worker(WorkerBase):
# sampling related tensors of max possible shape to avoid memory # sampling related tensors of max possible shape to avoid memory
# fragmentation issue. # fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent # NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`. # memory buffers from being cleared by `torch.accelerator.empty_cache`.
max_num_reqs = min( max_num_reqs = min(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_batched_tokens,
......
...@@ -46,7 +46,6 @@ def _torch_cuda_wrapper(): ...@@ -46,7 +46,6 @@ def _torch_cuda_wrapper():
if supports_xpu_graph(): if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.empty_cache = torch.xpu.empty_cache
yield yield
finally: finally:
pass pass
...@@ -62,7 +62,7 @@ class XPUWorker(Worker): ...@@ -62,7 +62,7 @@ class XPUWorker(Worker):
self.device = torch.device(f"xpu:{self.local_rank}") self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device) current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype) current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.xpu.empty_cache() torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties( self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank self.local_rank
).total_memory ).total_memory
...@@ -90,7 +90,7 @@ class XPUWorker(Worker): ...@@ -90,7 +90,7 @@ class XPUWorker(Worker):
# Now take memory snapshot after NCCL is initialized # Now take memory snapshot after NCCL is initialized
gc.collect() gc.collect()
torch.xpu.empty_cache() torch.accelerator.empty_cache()
# take current memory snapshot # take current memory snapshot
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment