Unverified commit 16d2ad1d, authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.empty_cache` with `torch.accelerator.empty_cache` (#30681)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5dc35387
...@@ -127,6 +127,13 @@ repos: ...@@ -127,6 +127,13 @@ repos:
language: python language: python
types: [python] types: [python]
additional_dependencies: [regex] additional_dependencies: [regex]
# Prevent new uses of torch.cuda APIs
- id: check-torch-cuda-call
name: "Prevent new 'torch.cuda' APIs call"
entry: python tools/pre_commit/check_torch_cuda.py
language: python
types: [python]
additional_dependencies: [regex]
- id: validate-config - id: validate-config
name: Validate configuration has default values and that each field has a docstring name: Validate configuration has default values and that each field has a docstring
entry: python tools/pre_commit/validate_config.py entry: python tools/pre_commit/validate_config.py
......
...@@ -102,7 +102,7 @@ def reset_memory_stats(): ...@@ -102,7 +102,7 @@ def reset_memory_stats():
"""Reset peak memory statistics.""" """Reset peak memory statistics."""
reset_buffer_cache() reset_buffer_cache()
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
gc.collect() gc.collect()
......
...@@ -54,7 +54,7 @@ def clear_triton_cache(): ...@@ -54,7 +54,7 @@ def clear_triton_cache():
# Clear CUDA memory cache # Clear CUDA memory cache
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# Try to clear Triton's runtime cache # Try to clear Triton's runtime cache
try: try:
......
...@@ -104,7 +104,7 @@ def run_benchmark( ...@@ -104,7 +104,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping # free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return lat return lat
......
...@@ -129,7 +129,7 @@ def run_benchmark( ...@@ -129,7 +129,7 @@ def run_benchmark(
# free tensors to mitigate OOM when sweeping # free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache() torch.accelerator.empty_cache()
return lat return lat
......
...@@ -120,7 +120,7 @@ def main(): ...@@ -120,7 +120,7 @@ def main():
# Clean up the GPU memory for the next test # Clean up the GPU memory for the next test
del engine del engine
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -159,7 +159,7 @@ class RayTrainingActor: ...@@ -159,7 +159,7 @@ class RayTrainingActor:
s.close() s.close()
del buffer del buffer
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# Ray manages four GPUs. # Ray manages four GPUs.
......
...@@ -150,7 +150,7 @@ class ColocateWorkerExtension: ...@@ -150,7 +150,7 @@ class ColocateWorkerExtension:
socket.close() socket.close()
del buffer del buffer
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
def report_device_id(self) -> str: def report_device_id(self) -> str:
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
...@@ -99,7 +99,7 @@ def test_dynamic_shapes_compilation( ...@@ -99,7 +99,7 @@ def test_dynamic_shapes_compilation(
# Clean up GPU memory # Clean up GPU memory
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
torch.cuda.synchronize() torch.cuda.synchronize()
print("GPU memory cleared") print("GPU memory cleared")
......
...@@ -1533,7 +1533,7 @@ def clean_gpu_memory_between_tests(): ...@@ -1533,7 +1533,7 @@ def clean_gpu_memory_between_tests():
# Clean up GPU memory after the test # Clean up GPU memory after the test
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.accelerator.empty_cache()
gc.collect() gc.collect()
......
...@@ -24,7 +24,7 @@ LORA_PATH = "davzoku/finqa_adapter_1b" ...@@ -24,7 +24,7 @@ LORA_PATH = "davzoku/finqa_adapter_1b"
def _cleanup(): def _cleanup():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
......
...@@ -273,7 +273,7 @@ def test_causal_conv1d_varlen( ...@@ -273,7 +273,7 @@ def test_causal_conv1d_varlen(
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
): ):
device = "cuda" device = "cuda"
torch.cuda.empty_cache() torch.accelerator.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16: if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2 rtol, atol = 1e-2, 5e-2
......
...@@ -769,7 +769,7 @@ def test_mixtral_moe( ...@@ -769,7 +769,7 @@ def test_mixtral_moe(
requires_grad=False, requires_grad=False,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# FIXME (zyongye) fix this after we move self.kernel # FIXME (zyongye) fix this after we move self.kernel
# assignment in FusedMoE.__init__ # assignment in FusedMoE.__init__
......
...@@ -178,7 +178,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): ...@@ -178,7 +178,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
finally: finally:
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref): def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref):
...@@ -200,7 +200,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref) ...@@ -200,7 +200,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref)
finally: finally:
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
...@@ -283,7 +283,7 @@ def test_vllm_tensorized_model_has_same_outputs( ...@@ -283,7 +283,7 @@ def test_vllm_tensorized_model_has_same_outputs(
model_ref, vllm_runner, tmp_path, model_path model_ref, vllm_runner, tmp_path, model_path
): ):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref) args = EngineArgs(model=model_ref)
......
...@@ -49,7 +49,7 @@ def test_gc(): ...@@ -49,7 +49,7 @@ def test_gc():
del llm del llm
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.accelerator.empty_cache()
# The memory allocated for model and KV cache should be released. # The memory allocated for model and KV cache should be released.
# The memory allocated for PyTorch and others should be less than 50MB. # The memory allocated for PyTorch and others should be less than 50MB.
......
...@@ -125,7 +125,7 @@ def test_no_sync_with_spec_decode( ...@@ -125,7 +125,7 @@ def test_no_sync_with_spec_decode(
assert len(outputs[0].outputs[0].text) > 0 assert len(outputs[0].outputs[0].text) > 0
del llm del llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
sync_tracker.assert_no_sync() sync_tracker.assert_no_sync()
...@@ -95,7 +95,7 @@ def test_batch_inference_correctness( ...@@ -95,7 +95,7 @@ def test_batch_inference_correctness(
prompts, sampling_params, lora_request=lora_request prompts, sampling_params, lora_request=lora_request
) )
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
lora_spec_llm = LLM( lora_spec_llm = LLM(
...@@ -135,5 +135,5 @@ def test_batch_inference_correctness( ...@@ -135,5 +135,5 @@ def test_batch_inference_correctness(
print(f"match ratio: {matches}/{len(ref_outputs)}") print(f"match ratio: {matches}/{len(ref_outputs)}")
assert matches > int(0.90 * len(ref_outputs)) assert matches > int(0.90 * len(ref_outputs))
del lora_spec_llm del lora_spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -440,7 +440,7 @@ def _run_ref_mamba_state_worker(): ...@@ -440,7 +440,7 @@ def _run_ref_mamba_state_worker():
torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth") torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
mamba_kv_cache_dict.clear() mamba_kv_cache_dict.clear()
del engine del engine
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
...@@ -805,5 +805,5 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): ...@@ -805,5 +805,5 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check) check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
mamba_kv_cache_dict.clear() mamba_kv_cache_dict.clear()
del engine del engine
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -179,7 +179,7 @@ def test_ngram_and_suffix_correctness( ...@@ -179,7 +179,7 @@ def test_ngram_and_suffix_correctness(
) )
evaluate_llm_for_gsm8k(spec_llm) evaluate_llm_for_gsm8k(spec_llm)
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -240,7 +240,7 @@ def test_suffix_decoding_acceptance( ...@@ -240,7 +240,7 @@ def test_suffix_decoding_acceptance(
assert last_accept_rate > 0.80 assert last_accept_rate > 0.80
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -307,14 +307,14 @@ def test_speculators_model_integration( ...@@ -307,14 +307,14 @@ def test_speculators_model_integration(
verifier_model = spec_llm.llm_engine.vllm_config.model_config.model verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Second run: Reference without speculative decoding # Second run: Reference without speculative decoding
ref_llm = LLM(model=verifier_model, max_model_len=4096) ref_llm = LLM(model=verifier_model, max_model_len=4096)
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Compare outputs # Compare outputs
...@@ -410,7 +410,7 @@ def _run_eagle_correctness( ...@@ -410,7 +410,7 @@ def _run_eagle_correctness(
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
spec_llm = LLM( spec_llm = LLM(
...@@ -445,7 +445,7 @@ def _run_eagle_correctness( ...@@ -445,7 +445,7 @@ def _run_eagle_correctness(
assert matches > int(0.6 * len(ref_outputs)) assert matches > int(0.6 * len(ref_outputs))
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -715,7 +715,7 @@ def test_mtp_correctness( ...@@ -715,7 +715,7 @@ def test_mtp_correctness(
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
) )
del ref_llm del ref_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
spec_llm = LLM( spec_llm = LLM(
...@@ -747,7 +747,7 @@ def test_mtp_correctness( ...@@ -747,7 +747,7 @@ def test_mtp_correctness(
# Upon failure, inspect the outputs to check for inaccuracy. # Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)) assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
...@@ -952,7 +952,7 @@ def assert_draft_model_correctness(args: ArgsTest): ...@@ -952,7 +952,7 @@ def assert_draft_model_correctness(args: ArgsTest):
) )
del spec_llm # CLEANUP del spec_llm # CLEANUP
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
print( print(
......
...@@ -857,7 +857,7 @@ def test_structured_output_batched_with_non_structured_outputs_requests( ...@@ -857,7 +857,7 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
# Free memory as soon as possible as failed assertions # Free memory as soon as possible as failed assertions
# will short circuit and not free up memory # will short circuit and not free up memory
del llm del llm
torch.cuda.empty_cache() torch.accelerator.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
for index, output in enumerate(outputs): for index, output in enumerate(outputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment