Unverified commit 66a22096, authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace `torch.cuda.synchronize()` api with `torch.accelerator.synchronize` (#36085)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 0bfa229b
...@@ -34,7 +34,7 @@ def do_profile( ...@@ -34,7 +34,7 @@ def do_profile(
record_shapes=True, record_shapes=True,
) as tprof: ) as tprof:
fn(**fn_kwargs) fn(**fn_kwargs)
torch.cuda.synchronize(torch.cuda.current_device()) torch.accelerator.synchronize(torch.cuda.current_device())
# TODO (varun): Add a descriptive trace file name # TODO (varun): Add a descriptive trace file name
tprof.export_chrome_trace( tprof.export_chrome_trace(
......
...@@ -318,8 +318,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) ...@@ -318,8 +318,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
out = deep_gemm_moe_fp8_fn( out = deep_gemm_moe_fp8_fn(
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035) torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
...@@ -399,9 +399,9 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -399,9 +399,9 @@ def test_cutlass_moe_8_bit_cuda_graph(
mt, topk_weights, topk_ids, per_act_token, per_out_ch mt, topk_weights, topk_ids, per_act_token, per_out_ch
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2) torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
......
...@@ -272,9 +272,9 @@ def run_moe_test( ...@@ -272,9 +272,9 @@ def run_moe_test(
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
graph.replay() graph.replay()
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
...@@ -768,7 +768,7 @@ def test_mixtral_moe( ...@@ -768,7 +768,7 @@ def test_mixtral_moe(
F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False, requires_grad=False,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
# FIXME (zyongye) fix this after we move self.kernel # FIXME (zyongye) fix this after we move self.kernel
......
...@@ -122,7 +122,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): ...@@ -122,7 +122,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
) )
output_ref = torch.matmul(input, w_ref) output_ref = torch.matmul(input, w_ref)
torch.cuda.synchronize() torch.accelerator.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04 assert max_diff < 0.04
...@@ -269,7 +269,7 @@ def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero): ...@@ -269,7 +269,7 @@ def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
setup.c_strides, setup.c_strides,
setup.group_scale_strides, setup.group_scale_strides,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
out_ref = compute_moe_reference_output(setup) out_ref = compute_moe_reference_output(setup)
torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)
......
...@@ -260,7 +260,7 @@ def test_gptq_marlin_repack( ...@@ -260,7 +260,7 @@ def test_gptq_marlin_repack(
marlin_q_w_2 = ops.gptq_marlin_repack( marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
...@@ -308,7 +308,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors): ...@@ -308,7 +308,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
marlin_q_w_2 = ops.awq_marlin_repack( marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
...@@ -564,7 +564,7 @@ def test_marlin_gemm_subset_input(): ...@@ -564,7 +564,7 @@ def test_marlin_gemm_subset_input():
) )
output_ref = torch.matmul(a_input, w_ref) output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize() torch.accelerator.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
...@@ -613,7 +613,7 @@ def test_marlin_gemm_with_bias(size_m): ...@@ -613,7 +613,7 @@ def test_marlin_gemm_with_bias(size_m):
) )
output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1) output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
torch.cuda.synchronize() torch.accelerator.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
......
...@@ -57,7 +57,7 @@ def test_gather_cache_oob(): ...@@ -57,7 +57,7 @@ def test_gather_cache_oob():
seq_starts, seq_starts,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
assert True assert True
......
...@@ -219,7 +219,7 @@ def _run_top_k_per_row_decode_test( ...@@ -219,7 +219,7 @@ def _run_top_k_per_row_decode_test(
top_k, top_k,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
# Run reference implementation # Run reference implementation
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
......
...@@ -195,4 +195,4 @@ def test_models( ...@@ -195,4 +195,4 @@ def test_models(
# unit tests. On ROCm, when using AITER # unit tests. On ROCm, when using AITER
# the memory might not be deallocated completely # the memory might not be deallocated completely
# before running the next test case # before running the next test case
torch.cuda.synchronize() torch.accelerator.synchronize()
...@@ -196,7 +196,7 @@ def test_compressed_tensors_w8a8_logprobs( ...@@ -196,7 +196,7 @@ def test_compressed_tensors_w8a8_logprobs(
) )
if current_platform.is_rocm(): if current_platform.is_rocm():
torch.cuda.synchronize() torch.accelerator.synchronize()
def test_compressed_tensors_no_enforce_eager(vllm_runner): def test_compressed_tensors_no_enforce_eager(vllm_runner):
......
...@@ -9,6 +9,7 @@ import regex as re ...@@ -9,6 +9,7 @@ import regex as re
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [ _TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.empty_cache\b", r"\btorch\.cuda\.empty_cache\b",
r"\btorch\.cuda\.synchronize\b",
] ]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"} ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
......
...@@ -217,7 +217,7 @@ class ElasticEPScalingExecutor: ...@@ -217,7 +217,7 @@ class ElasticEPScalingExecutor:
dp_group=standby_dp_group, dp_group=standby_dp_group,
expert_weights=model.expert_weights, expert_weights=model.expert_weights,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
def broadcast_expert_mapping(self) -> None: def broadcast_expert_mapping(self) -> None:
standby_dp_group = get_standby_dp_group() standby_dp_group = get_standby_dp_group()
...@@ -407,7 +407,7 @@ class ElasticEPScalingExecutor: ...@@ -407,7 +407,7 @@ class ElasticEPScalingExecutor:
reset_compile_wrapper(self.worker.model_runner.get_model()) reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect() gc.collect()
torch.cuda.synchronize() torch.accelerator.synchronize()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
unlock_workspace() unlock_workspace()
self.worker.compile_or_warm_up_model() self.worker.compile_or_warm_up_model()
...@@ -446,7 +446,7 @@ class ElasticEPScalingExecutor: ...@@ -446,7 +446,7 @@ class ElasticEPScalingExecutor:
eplb_state.rearrange(rank_mapping=rank_mapping) eplb_state.rearrange(rank_mapping=rank_mapping)
# NOTE(yongji): check whether we need to synchronize here # NOTE(yongji): check whether we need to synchronize here
torch.cuda.synchronize() torch.accelerator.synchronize()
# reset expert_rearrangement_step to ensure all ranks are synchronized # reset expert_rearrangement_step to ensure all ranks are synchronized
eplb_state.expert_rearrangement_step = 0 eplb_state.expert_rearrangement_step = 0
eplb_state.num_valid_physical_experts = ( eplb_state.num_valid_physical_experts = (
...@@ -491,7 +491,7 @@ class ElasticEPScalingExecutor: ...@@ -491,7 +491,7 @@ class ElasticEPScalingExecutor:
dp_group=dp_group, dp_group=dp_group,
expert_weights=model.expert_weights, expert_weights=model.expert_weights,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
dp_group = get_dp_group() dp_group = get_dp_group()
......
...@@ -622,7 +622,7 @@ def rearrange_expert_weights_inplace( ...@@ -622,7 +622,7 @@ def rearrange_expert_weights_inplace(
# NOTE(bowen): We need this synchronize to run, but I don't know why. # NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you! # If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize() torch.accelerator.synchronize()
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
......
...@@ -77,7 +77,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel): ...@@ -77,7 +77,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
def transform_w_q(x): def transform_w_q(x):
assert isinstance(x, BasevLLMParameter) assert isinstance(x, BasevLLMParameter)
convert_packed_uint4b8_to_signed_int4_inplace(x.data) convert_packed_uint4b8_to_signed_int4_inplace(x.data)
torch.cuda.synchronize() torch.accelerator.synchronize()
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x return x
......
...@@ -457,7 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -457,7 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
self._dummy_pooler_run(hidden_states) self._dummy_pooler_run(hidden_states)
torch.cuda.synchronize() torch.accelerator.synchronize()
del hidden_states, sample_hidden_states del hidden_states, sample_hidden_states
gc.collect() gc.collect()
...@@ -525,7 +525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -525,7 +525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# to trigger JIT compilation. # to trigger JIT compilation.
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()): if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
self._dummy_run(self.max_num_tokens, skip_attn=False) self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize() torch.accelerator.synchronize()
def finish_requests(self, scheduler_output: SchedulerOutput) -> None: def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
finished_req_ids = scheduler_output.finished_req_ids finished_req_ids = scheduler_output.finished_req_ids
......
...@@ -102,4 +102,4 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None: ...@@ -102,4 +102,4 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
cleanup_output.finished_req_ids = set(req_ids) cleanup_output.finished_req_ids = set(req_ids)
model_runner.execute_model(cleanup_output) model_runner.execute_model(cleanup_output)
model_runner.kv_connector.set_disabled(False) model_runner.kv_connector.set_disabled(False)
torch.cuda.synchronize() torch.accelerator.synchronize()
...@@ -928,7 +928,7 @@ class GPUModelRunner( ...@@ -928,7 +928,7 @@ class GPUModelRunner(
# Note: used for model runner override. # Note: used for model runner override.
def _sync_device(self) -> None: def _sync_device(self) -> None:
torch.cuda.synchronize() torch.accelerator.synchronize()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler """Update the cached states and the persistent batch with the scheduler
...@@ -5345,7 +5345,7 @@ class GPUModelRunner( ...@@ -5345,7 +5345,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=runtime_mode, cudagraph_runtime_mode=runtime_mode,
) )
torch.cuda.synchronize() torch.accelerator.synchronize()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
# Disable cudagraph capturing globally, so any unexpected cudagraph # Disable cudagraph capturing globally, so any unexpected cudagraph
...@@ -6266,13 +6266,13 @@ class GPUModelRunner( ...@@ -6266,13 +6266,13 @@ class GPUModelRunner(
group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items] group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items]
group_request_ids = {req_id for req_id, _ in group_refs} group_request_ids = {req_id for req_id, _ in group_refs}
torch.cuda.synchronize() torch.accelerator.synchronize()
start_time = time.perf_counter() start_time = time.perf_counter()
try: try:
yield yield
finally: finally:
torch.cuda.synchronize() torch.accelerator.synchronize()
elapsed = time.perf_counter() - start_time elapsed = time.perf_counter() - start_time
per_request_time = elapsed / max(len(group_request_ids), 1) per_request_time = elapsed / max(len(group_request_ids), 1)
......
...@@ -29,9 +29,6 @@ class XPUModelRunner(GPUModelRunner): ...@@ -29,9 +29,6 @@ class XPUModelRunner(GPUModelRunner):
# FIXME: To be verified. # FIXME: To be verified.
self.cascade_attn_enabled = False self.cascade_attn_enabled = False
def _sync_device(self) -> None:
torch.xpu.synchronize()
@contextmanager @contextmanager
def _torch_cuda_wrapper(): def _torch_cuda_wrapper():
...@@ -42,7 +39,6 @@ def _torch_cuda_wrapper(): ...@@ -42,7 +39,6 @@ def _torch_cuda_wrapper():
torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream torch.cuda.stream = torch.xpu.stream
torch.cuda.mem_get_info = torch.xpu.mem_get_info torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.synchronize = torch.xpu.synchronize
if supports_xpu_graph(): if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph torch.cuda.CUDAGraph = torch.xpu.XPUGraph
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment