"tests/vscode:/vscode.git/clone" did not exist on "27208be66e6529d016358d15fe87e95810698227"
Unverified Commit 53ec16a7 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
...@@ -28,7 +28,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular): ...@@ -28,7 +28,7 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
max_capture_size, max_capture_size,
): ):
super().__init__(moe_config, quant_config) super().__init__(moe_config, quant_config)
self.device = torch.cuda.current_device() self.device = torch.accelerator.current_device_index()
self.num_experts = moe_config.num_local_experts self.num_experts = moe_config.num_local_experts
self.gemm1_alpha = torch.tensor( self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device [1.702] * self.num_experts, dtype=torch.float32, device=self.device
......
...@@ -202,7 +202,7 @@ class RMSNorm(CustomOp): ...@@ -202,7 +202,7 @@ class RMSNorm(CustomOp):
# external Oink initialization work in this case. # external Oink initialization work in this case.
else: else:
try: try:
device_index = torch.cuda.current_device() device_index = torch.accelerator.current_device_index()
if _oink_ops.is_oink_available_for_device(device_index): if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = ( self._use_oink_fused_add_rmsnorm = (
......
...@@ -36,7 +36,8 @@ class DualChunkRotaryEmbedding(CustomOp): ...@@ -36,7 +36,8 @@ class DualChunkRotaryEmbedding(CustomOp):
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.local_size = local_size self.local_size = local_size
self.dtype = dtype self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}") device_idx = torch.accelerator.current_device_index()
self.device = torch.device(f"cuda:{device_idx}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = ( (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
self._compute_cos_sin_cache() self._compute_cos_sin_cache()
) )
......
...@@ -539,6 +539,8 @@ def deserialize_tensorizer_model( ...@@ -539,6 +539,8 @@ def deserialize_tensorizer_model(
) )
before_mem = get_mem_usage() before_mem = get_mem_usage()
start = time.perf_counter() start = time.perf_counter()
device_index = torch.accelerator.current_device_index()
device_type = current_platform.device_type
with ( with (
open_stream( open_stream(
tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
...@@ -546,9 +548,7 @@ def deserialize_tensorizer_model( ...@@ -546,9 +548,7 @@ def deserialize_tensorizer_model(
TensorDeserializer( TensorDeserializer(
stream, stream,
dtype=tensorizer_config.dtype, dtype=tensorizer_config.dtype,
device=f"xpu:{torch.xpu.current_device()}" device=f"{device_type}:{device_index}",
if current_platform.is_xpu()
else f"cuda:{torch.cuda.current_device()}",
**tensorizer_args.deserialization_kwargs, **tensorizer_args.deserialization_kwargs,
) as deserializer, ) as deserializer,
): ):
......
...@@ -624,7 +624,7 @@ def cuda_device_count_stateless() -> int: ...@@ -624,7 +624,7 @@ def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of """Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call. CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count() This should be used instead of torch.accelerator.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired unless CUDA_VISIBLE_DEVICES has already been set to the desired
value.""" value."""
......
...@@ -134,7 +134,7 @@ class CoreEngineProcManager: ...@@ -134,7 +134,7 @@ class CoreEngineProcManager:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks): for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms # Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers # as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device() # For CUDA platforms, we use torch.accelerator.set_device_index()()
if is_dp and ( if is_dp and (
not current_platform.is_cuda_alike() not current_platform.is_cuda_alike()
or vllm_config.parallel_config.use_ray or vllm_config.parallel_config.use_ray
......
...@@ -73,8 +73,8 @@ class SMControlContextManager: ...@@ -73,8 +73,8 @@ class SMControlContextManager:
assert current_platform.is_cuda(), ( assert current_platform.is_cuda(), (
"SM control is currently only supported on CUDA" "SM control is currently only supported on CUDA"
) )
device = torch.accelerator.current_device_index()
total_sms = num_compute_units(torch.cuda.current_device()) total_sms = num_compute_units(device)
assert comm_sms < total_sms assert comm_sms < total_sms
self.total_sms = total_sms self.total_sms = total_sms
...@@ -204,7 +204,7 @@ class UBatchWrapper: ...@@ -204,7 +204,7 @@ class UBatchWrapper:
@torch.inference_mode() @torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata): def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device) torch.accelerator.set_device_index(self.device)
ubatch_context = ubatch_metadata.context ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream): with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle() _ = torch.cuda.current_blas_handle()
......
...@@ -239,11 +239,11 @@ class Worker(WorkerBase): ...@@ -239,11 +239,11 @@ class Worker(WorkerBase):
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), ( assert self.local_rank < torch.accelerator.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. " f"DP adjusted local rank {self.local_rank} is out of bounds. "
) )
visible_device_count = ( visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0 torch.accelerator.device_count() if torch.cuda.is_available() else 0
) )
assert self.parallel_config.local_world_size <= visible_device_count, ( assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must " f"local_world_size ({self.parallel_config.local_world_size}) must "
...@@ -252,7 +252,7 @@ class Worker(WorkerBase): ...@@ -252,7 +252,7 @@ class Worker(WorkerBase):
) )
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device) torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype) current_platform.check_if_supports_dtype(self.model_config.dtype)
......
...@@ -60,7 +60,7 @@ class XPUWorker(Worker): ...@@ -60,7 +60,7 @@ class XPUWorker(Worker):
and current_platform.is_xpu() and current_platform.is_xpu()
): ):
self.device = torch.device(f"xpu:{self.local_rank}") self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device) torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype) current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties( self.init_gpu_memory = torch.xpu.get_device_properties(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment