Unverified Commit 16d2ad1d authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.empty_cache` with `torch.accelerator.empty_cache` (#30681)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5dc35387
......@@ -530,7 +530,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
assert positive_values > 0
finally:
del llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -1065,7 +1065,7 @@ def test_spec_decode_logprobs(
for logprobs in output.logprobs:
ref_logprobs.extend(logprobs.values())
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
......@@ -1095,7 +1095,7 @@ def test_spec_decode_logprobs(
for logprobs in output.logprobs:
spec_logprobs.extend(logprobs.values())
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Per-token logprobs are expected to be the same.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# --------------------------------------------------------------------------- #
# Regexes for the `torch.cuda.xxx` calls this check rejects; the
# `torch.accelerator.xxx` equivalents do not match and are allowed.
# --------------------------------------------------------------------------- #
# One regular expression per rejected `torch.cuda` call.
_TORCH_CUDA_PATTERNS = [r"\btorch\.cuda\.empty_cache\b"]

# Files whose path starts with one of these prefixes are never scanned.
ALLOWED_FILES = {
    "vllm/device_allocator/",
    "vllm/platforms/",
}
def scan_file(path: str) -> int:
with open(path, encoding="utf-8") as f:
content = f.read()
for pattern in _TORCH_CUDA_PATTERNS:
for match in re.finditer(pattern, content, re.MULTILINE):
# Calculate line number from match position
line_num = content[: match.start() + 1].count("\n") + 1
print(
f"{path}:{line_num}: "
"\033[91merror:\033[0m " # red color
"Found torch.cuda API call"
)
return 1
return 0
def main():
    """Scan each file named on the command line and return an exit status.

    Files whose path starts with an ``ALLOWED_FILES`` prefix are skipped.
    The result is the bitwise OR of every ``scan_file`` return value, so it
    is 0 only when no scanned file reported a violation.
    """
    candidates = (
        name
        for name in sys.argv[1:]
        if not name.startswith(tuple(ALLOWED_FILES))
    )
    status = 0
    for name in candidates:
        # Every remaining file is scanned; there is no early exit.
        status |= scan_file(name)
    return status
# Run the check when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
......@@ -260,7 +260,9 @@ class CUDAGraphWrapper:
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
stack.enter_context(
patch("torch.accelerator.empty_cache", lambda: None)
)
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
......
......@@ -408,7 +408,7 @@ class ElasticEPScalingExecutor:
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
......
......@@ -1916,14 +1916,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
gc.collect()
from vllm.platforms import current_platform
empty_cache = current_platform.empty_cache
if empty_cache is not None:
empty_cache()
try:
if not current_platform.is_cpu():
if not current_platform.is_cpu():
torch.accelerator.empty_cache()
try:
torch._C._host_emptyCache()
except AttributeError:
logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
except AttributeError:
logger.warning(
"torch._C._host_emptyCache() only available in Pytorch >=2.5"
)
def in_the_same_node_as(
......
......@@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight
......
......@@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# secondly, process mxfp weights
if self.emulate:
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return
from aiter.utility.fp4_utils import e8m0_shuffle
......@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
......@@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
......
......@@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
import torch.nn.functional as F
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight
......
......@@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
......@@ -96,7 +96,7 @@ class MemorySnapshot:
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
# when we call `torch.accelerator.empty_cache()` or OOM happens.
self.torch_peak = current_platform.memory_stats(device).get(
"allocated_bytes.all.peak", 0
)
......@@ -250,7 +250,7 @@ def memory_profiling(
until after profiling to get (c.).
"""
gc.collect()
current_platform.empty_cache()
torch.accelerator.empty_cache()
current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
result = MemoryProfilingResult(
......@@ -264,7 +264,7 @@ def memory_profiling(
yield result
gc.collect()
current_platform.empty_cache()
torch.accelerator.empty_cache()
result.after_profile.measure()
......
......@@ -1036,4 +1036,4 @@ def apply_top_k_top_p_triton(
def reset_buffer_cache():
_TRITON_BUFFER_CACHE.clear()
_TRITON_TABLE_CACHE.clear()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
......@@ -496,7 +496,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_time = time.perf_counter()
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
......
......@@ -278,7 +278,7 @@ class Worker(WorkerBase):
# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
# take current memory snapshot
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
......@@ -585,7 +585,7 @@ class Worker(WorkerBase):
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
# memory buffers from being cleared by `torch.accelerator.empty_cache`.
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
......
......@@ -46,7 +46,6 @@ def _torch_cuda_wrapper():
if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.empty_cache = torch.xpu.empty_cache
yield
finally:
pass
......@@ -62,7 +62,7 @@ class XPUWorker(Worker):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.xpu.empty_cache()
torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank
).total_memory
......@@ -90,7 +90,7 @@ class XPUWorker(Worker):
# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.xpu.empty_cache()
torch.accelerator.empty_cache()
# take current memory snapshot
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment