Unverified Commit 53ec16a7 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
......@@ -28,7 +28,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
max_capture_size,
):
super().__init__(moe_config, quant_config)
self.device = torch.cuda.current_device()
self.device = torch.accelerator.current_device_index()
self.num_experts = moe_config.num_local_experts
self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
......
......@@ -202,7 +202,7 @@ class RMSNorm(CustomOp):
# external Oink initialization work in this case.
else:
try:
device_index = torch.cuda.current_device()
device_index = torch.accelerator.current_device_index()
if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = (
......
......@@ -36,7 +36,8 @@ class DualChunkRotaryEmbedding(CustomOp):
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
device_idx = torch.accelerator.current_device_index()
self.device = torch.device(f"cuda:{device_idx}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
self._compute_cos_sin_cache()
)
......
......@@ -539,6 +539,8 @@ def deserialize_tensorizer_model(
)
before_mem = get_mem_usage()
start = time.perf_counter()
device_index = torch.accelerator.current_device_index()
device_type = current_platform.device_type
with (
open_stream(
tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
......@@ -546,9 +548,7 @@ def deserialize_tensorizer_model(
TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=f"xpu:{torch.xpu.current_device()}"
if current_platform.is_xpu()
else f"cuda:{torch.cuda.current_device()}",
device=f"{device_type}:{device_index}",
**tensorizer_args.deserialization_kwargs,
) as deserializer,
):
......
......@@ -624,7 +624,7 @@ def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
This should be used instead of torch.accelerator.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
......
......@@ -134,7 +134,7 @@ class CoreEngineProcManager:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device()
# For CUDA platforms, we use torch.accelerator.set_device_index()()
if is_dp and (
not current_platform.is_cuda_alike()
or vllm_config.parallel_config.use_ray
......
......@@ -73,8 +73,8 @@ class SMControlContextManager:
assert current_platform.is_cuda(), (
"SM control is currently only supported on CUDA"
)
total_sms = num_compute_units(torch.cuda.current_device())
device = torch.accelerator.current_device_index()
total_sms = num_compute_units(device)
assert comm_sms < total_sms
self.total_sms = total_sms
......@@ -204,7 +204,7 @@ class UBatchWrapper:
@torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device)
torch.accelerator.set_device_index(self.device)
ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle()
......
......@@ -239,11 +239,11 @@ class Worker(WorkerBase):
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), (
assert self.local_rank < torch.accelerator.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
torch.accelerator.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
......@@ -252,7 +252,7 @@ class Worker(WorkerBase):
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
......
......@@ -60,7 +60,7 @@ class XPUWorker(Worker):
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment