Unverified commit 16d2ad1d, authored by Kunshang Ji and committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.empty_cache` with `torch.accelerator.empty_cache` (#30681)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5dc35387
......@@ -127,6 +127,13 @@ repos:
language: python
types: [python]
additional_dependencies: [regex]
# prevent use of torch.cuda APIs
- id: check-torch-cuda-call
name: "Prevent new 'torch.cuda' APIs call"
entry: python tools/pre_commit/check_torch_cuda.py
language: python
types: [python]
additional_dependencies: [regex]
- id: validate-config
name: Validate configuration has default values and that each field has a docstring
entry: python tools/pre_commit/validate_config.py
......
......@@ -102,7 +102,7 @@ def reset_memory_stats():
"""Reset peak memory statistics."""
reset_buffer_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
gc.collect()
......
......@@ -54,7 +54,7 @@ def clear_triton_cache():
# Clear CUDA memory cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
# Try to clear Triton's runtime cache
try:
......
......@@ -104,7 +104,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return lat
......
......@@ -129,7 +129,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return lat
......
......@@ -120,7 +120,7 @@ def main():
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
if __name__ == "__main__":
......
......@@ -159,7 +159,7 @@ class RayTrainingActor:
s.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
# Ray manages four GPUs.
......
......@@ -150,7 +150,7 @@ class ColocateWorkerExtension:
socket.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def report_device_id(self) -> str:
from vllm.platforms import current_platform
......
......@@ -99,7 +99,7 @@ def test_dynamic_shapes_compilation(
# Clean up GPU memory
del model
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
torch.cuda.synchronize()
print("GPU memory cleared")
......
......@@ -1533,7 +1533,7 @@ def clean_gpu_memory_between_tests():
# Clean up GPU memory after the test
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
gc.collect()
......
......@@ -24,7 +24,7 @@ LORA_PATH = "davzoku/finqa_adapter_1b"
def _cleanup():
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
@pytest.fixture(autouse=True)
......
......@@ -273,7 +273,7 @@ def test_causal_conv1d_varlen(
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
device = "cuda"
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
......
......@@ -769,7 +769,7 @@ def test_mixtral_moe(
requires_grad=False,
)
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
# FIXME (zyongye) fix this after we move self.kernel
# assignment in FusedMoE.__init__
......
......@@ -178,7 +178,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
finally:
del model
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref):
......@@ -200,7 +200,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref)
finally:
del model
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
......@@ -283,7 +283,7 @@ def test_vllm_tensorized_model_has_same_outputs(
model_ref, vllm_runner, tmp_path, model_path
):
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref)
......
......@@ -49,7 +49,7 @@ def test_gc():
del llm
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
# The memory allocated for model and KV cache should be released.
# The memory allocated for PyTorch and others should be less than 50MB.
......
......@@ -125,7 +125,7 @@ def test_no_sync_with_spec_decode(
assert len(outputs[0].outputs[0].text) > 0
del llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
sync_tracker.assert_no_sync()
......@@ -95,7 +95,7 @@ def test_batch_inference_correctness(
prompts, sampling_params, lora_request=lora_request
)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
lora_spec_llm = LLM(
......@@ -135,5 +135,5 @@ def test_batch_inference_correctness(
print(f"match ratio: {matches}/{len(ref_outputs)}")
assert matches > int(0.90 * len(ref_outputs))
del lora_spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -440,7 +440,7 @@ def _run_ref_mamba_state_worker():
torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
mamba_kv_cache_dict.clear()
del engine
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
except Exception:
traceback.print_exc()
......@@ -805,5 +805,5 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
mamba_kv_cache_dict.clear()
del engine
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -179,7 +179,7 @@ def test_ngram_and_suffix_correctness(
)
evaluate_llm_for_gsm8k(spec_llm)
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -240,7 +240,7 @@ def test_suffix_decoding_acceptance(
assert last_accept_rate > 0.80
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -307,14 +307,14 @@ def test_speculators_model_integration(
verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Second run: Reference without speculative decoding
ref_llm = LLM(model=verifier_model, max_model_len=4096)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Compare outputs
......@@ -410,7 +410,7 @@ def _run_eagle_correctness(
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
......@@ -445,7 +445,7 @@ def _run_eagle_correctness(
assert matches > int(0.6 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -715,7 +715,7 @@ def test_mtp_correctness(
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
......@@ -747,7 +747,7 @@ def test_mtp_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -952,7 +952,7 @@ def assert_draft_model_correctness(args: ArgsTest):
)
del spec_llm # CLEANUP
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
print(
......
......@@ -857,7 +857,7 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
# Free memory as soon as possible as failed assertions
# will short circuit and not free up memory
del llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
for index, output in enumerate(outputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment