Unverified commit 53ec16a7, authored by Kunshang Ji and committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
from vllm.utils.platform_utils import is_uva_available from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.") @pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
......
...@@ -71,7 +71,7 @@ def mixer2_gated_norm_tensor_parallel( ...@@ -71,7 +71,7 @@ def mixer2_gated_norm_tensor_parallel(
set_random_seed(0) set_random_seed(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
......
...@@ -322,7 +322,7 @@ class WeightTensors: ...@@ -322,7 +322,7 @@ class WeightTensors:
) )
def to_current_device(self): def to_current_device(self):
device = torch.cuda.current_device() device = torch.accelerator.current_device_index()
self.w1 = self.w1.to(device=device) self.w1 = self.w1.to(device=device)
self.w2 = self.w2.to(device=device) self.w2 = self.w2.to(device=device)
...@@ -392,7 +392,8 @@ class RankTensors: ...@@ -392,7 +392,8 @@ class RankTensors:
Return hidden_states Return hidden_states
""" """
m, k, dtype = (config.M, config.K, config.dtype) m, k, dtype = (config.M, config.K, config.dtype)
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0 device = torch.accelerator.current_device_index()
a = torch.randn((m, k), device=device, dtype=dtype) / 15.0
if config.quant_dtype is None: if config.quant_dtype is None:
return a, None return a, None
...@@ -428,9 +429,10 @@ class RankTensors: ...@@ -428,9 +429,10 @@ class RankTensors:
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
# distribute topk_ids evenly # distribute topk_ids evenly
device = torch.accelerator.current_device_index()
for mi in range(m): for mi in range(m):
topk_ids[mi] = torch.randperm(config.E)[:topk] topk_ids[mi] = torch.randperm(config.E)[:topk]
topk_ids = topk_ids.to(device=torch.cuda.current_device()) topk_ids = topk_ids.to(device=device)
expert_map = None expert_map = None
if config.world_size > 1 and config.supports_expert_map(): if config.world_size > 1 and config.supports_expert_map():
...@@ -440,9 +442,7 @@ class RankTensors: ...@@ -440,9 +442,7 @@ class RankTensors:
s = pgi.rank * num_local_experts s = pgi.rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts))) expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
expert_map = expert_map.to( expert_map = expert_map.to(device=device, dtype=torch.int32)
device=torch.cuda.current_device(), dtype=torch.int32
)
return RankTensors( return RankTensors(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -558,7 +558,9 @@ def reference_moe_impl( ...@@ -558,7 +558,9 @@ def reference_moe_impl(
def _make_gscale(num_experts: int) -> torch.Tensor: def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones( return torch.ones(
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32 (num_experts,),
device=torch.accelerator.current_device_index(),
dtype=torch.float32,
) )
......
...@@ -66,7 +66,7 @@ def _worker_parallel_launch( ...@@ -66,7 +66,7 @@ def _worker_parallel_launch(
**kwargs: P.kwargs, **kwargs: P.kwargs,
) -> None: ) -> None:
rank = node_rank * world_local_size + local_rank rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank) torch.accelerator.set_device_index(local_rank)
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl", backend="cpu:gloo,cuda:nccl",
......
...@@ -34,7 +34,8 @@ def do_profile( ...@@ -34,7 +34,8 @@ def do_profile(
record_shapes=True, record_shapes=True,
) as tprof: ) as tprof:
fn(**fn_kwargs) fn(**fn_kwargs)
torch.accelerator.synchronize(torch.cuda.current_device()) device = torch.accelerator.current_device_index()
torch.accelerator.synchronize(device=device)
# TODO (varun): Add a descriptive trace file name # TODO (varun): Add a descriptive trace file name
tprof.export_chrome_trace( tprof.export_chrome_trace(
......
...@@ -52,7 +52,7 @@ def _worker_parallel_launch( ...@@ -52,7 +52,7 @@ def _worker_parallel_launch(
**kwargs: P.kwargs, **kwargs: P.kwargs,
) -> None: ) -> None:
rank = node_rank * world_local_size + local_rank rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank) torch.accelerator.set_device_index(local_rank)
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl", backend="cpu:gloo,cuda:nccl",
......
...@@ -134,10 +134,8 @@ class TestTensors: ...@@ -134,10 +134,8 @@ class TestTensors:
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
device = torch.accelerator.current_device_index()
rank_tokens = ( rank_tokens = torch.randn((m, k), device=device, dtype=dtype) / 10.0
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
)
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None rank_token_scales = None
...@@ -145,11 +143,13 @@ class TestTensors: ...@@ -145,11 +143,13 @@ class TestTensors:
low=0, low=0,
high=config.num_experts, high=config.num_experts,
size=(m, topk), size=(m, topk),
device=torch.cuda.current_device(), device=device,
).to(dtype=torch.int64) ).to(dtype=torch.int64)
topk_weights = torch.randn( topk_weights = torch.randn(
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device() topk_ids.shape,
dtype=torch.float32,
device=device,
) )
return TestTensors( return TestTensors(
...@@ -296,7 +296,8 @@ def deepep_deepgemm_moe_impl( ...@@ -296,7 +296,8 @@ def deepep_deepgemm_moe_impl(
s = pgi.rank * num_local_experts s = pgi.rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts))) expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) device = torch.accelerator.current_device_index()
return expert_map.to(device=device, dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config( quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale, w1_scale=w1_scale,
...@@ -376,10 +377,11 @@ def _test_deepep_deepgemm_moe( ...@@ -376,10 +377,11 @@ def _test_deepep_deepgemm_moe(
set_random_seed(pgi.rank) set_random_seed(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device()) device = torch.accelerator.current_device_index()
w2 = w2.to(device=torch.cuda.current_device()) w1 = w1.to(device=device)
w1_scale = w1_scale.to(device=torch.cuda.current_device()) w2 = w2.to(device=device)
w2_scale = w2_scale.to(device=torch.cuda.current_device()) w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device)
pg = torch.distributed.new_group(list(range(pgi.world_size))) pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank) test_tensors = TestTensors.make(config, pgi.rank)
......
...@@ -210,7 +210,8 @@ def deep_ep_moe_impl( ...@@ -210,7 +210,8 @@ def deep_ep_moe_impl(
s = pgi.rank * num_local_experts s = pgi.rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts))) expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) device = torch.accelerator.current_device_index()
return expert_map.to(device=device, dtype=torch.int32)
hidden_size = test_tensors.rank_tokens.size(1) hidden_size = test_tensors.rank_tokens.size(1)
is_quantized = w1.dtype == torch.float8_e4m3fn is_quantized = w1.dtype == torch.float8_e4m3fn
...@@ -365,15 +366,13 @@ def _deep_ep_moe( ...@@ -365,15 +366,13 @@ def _deep_ep_moe(
) )
is_quantized = w1.dtype == torch.float8_e4m3fn is_quantized = w1.dtype == torch.float8_e4m3fn
w1 = w1.to(device=torch.cuda.current_device()) device_idx = torch.accelerator.current_device_index()
w2 = w2.to(device=torch.cuda.current_device()) w1 = w1.to(device=device_idx)
w2 = w2.to(device=device_idx)
if is_quantized: if is_quantized:
w1_scale = w1_scale.to( # type: ignore assert w1_scale is not None and w2_scale is not None
device=torch.cuda.current_device() w1_scale = w1_scale.to(device=device_idx)
) w2_scale = w2_scale.to(device=device_idx)
w2_scale = w2_scale.to( # type: ignore
device=torch.cuda.current_device()
)
pg = torch.distributed.new_group(list(range(pgi.world_size))) pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode) test_tensors = TestTensors.make(config, low_latency_mode)
......
...@@ -716,7 +716,7 @@ def test_mixtral_moe( ...@@ -716,7 +716,7 @@ def test_mixtral_moe(
monkeypatch.setenv("MASTER_ADDR", "localhost") monkeypatch.setenv("MASTER_ADDR", "localhost")
monkeypatch.setenv("MASTER_PORT", "12345") monkeypatch.setenv("MASTER_PORT", "12345")
init_distributed_environment() init_distributed_environment()
init_workspace_manager(torch.cuda.current_device()) init_workspace_manager(torch.accelerator.current_device_index())
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict() vllm_config.compilation_config.static_forward_context = dict()
......
...@@ -71,10 +71,10 @@ def enable_pickle(monkeypatch): ...@@ -71,10 +71,10 @@ def enable_pickle(monkeypatch):
) )
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
if torch.cuda.device_count() < model_case.tp: if torch.accelerator.device_count() < model_case.tp:
pytest.skip( pytest.skip(
f"This test requires >={model_case.tp} gpus, got only " f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}" f"{torch.accelerator.device_count()}"
) )
# `cudagraph_capture_sizes=[16]` to reduce load time. # `cudagraph_capture_sizes=[16]` to reduce load time.
......
...@@ -15,7 +15,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -15,7 +15,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
......
...@@ -40,7 +40,9 @@ MNK_FACTORS = [ ...@@ -40,7 +40,9 @@ MNK_FACTORS = [
(512, 24576, 128), (512, 24576, 128),
] ]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
# -1 means full extent in that dimension # -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1) TENSORWISE_GROUP_SHAPE = (-1, -1)
......
...@@ -29,7 +29,9 @@ if current_platform.is_rocm(): ...@@ -29,7 +29,9 @@ if current_platform.is_rocm():
allow_module_level=True, allow_module_level=True,
) )
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of # unit tests to a common utility function. Currently the use of
......
...@@ -13,7 +13,7 @@ except ImportError: ...@@ -13,7 +13,7 @@ except ImportError:
) )
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device") @pytest.mark.skipif(torch.accelerator.device_count() < 1, reason="Need CUDA device")
def test_gather_cache_oob(): def test_gather_cache_oob():
""" """
Tests for OOB read in gather_and_maybe_dequant_cache (Issue #27909). Tests for OOB read in gather_and_maybe_dequant_cache (Issue #27909).
......
...@@ -13,7 +13,9 @@ QUANT_DTYPES = [current_platform.fp8_dtype()] ...@@ -13,7 +13,9 @@ QUANT_DTYPES = [current_platform.fp8_dtype()]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
def ref_impl( def ref_impl(
......
...@@ -638,7 +638,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( ...@@ -638,7 +638,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
set_random_seed(seed) set_random_seed(seed)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
......
...@@ -61,7 +61,7 @@ pytestmark = pytest.mark.skipif( ...@@ -61,7 +61,7 @@ pytestmark = pytest.mark.skipif(
) )
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] [f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
if current_platform.is_cuda_alike() if current_platform.is_cuda_alike()
else ["cpu"] else ["cpu"]
) )
...@@ -260,7 +260,7 @@ def test_embeddings( ...@@ -260,7 +260,7 @@ def test_embeddings(
# device, see: https://github.com/triton-lang/triton/issues/2925 # device, see: https://github.com/triton-lang/triton/issues/2925
# Same below. # Same below.
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
...@@ -359,7 +359,7 @@ def test_lm_head_logits_processor( ...@@ -359,7 +359,7 @@ def test_lm_head_logits_processor(
default_vllm_config, dist_init, num_loras, device, vocab_size, stage default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None: ) -> None:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
...@@ -476,7 +476,7 @@ def test_lm_head_logits_processor_invalid_vocab_size( ...@@ -476,7 +476,7 @@ def test_lm_head_logits_processor_invalid_vocab_size(
) -> None: ) -> None:
"""Test that LogitsProcessorWithLoRA raises ValueError for invalid vocab sizes.""" """Test that LogitsProcessorWithLoRA raises ValueError for invalid vocab sizes."""
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
...@@ -505,7 +505,7 @@ def test_linear_replicated( ...@@ -505,7 +505,7 @@ def test_linear_replicated(
stage, stage,
) -> None: ) -> None:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
...@@ -612,7 +612,7 @@ def test_linear_parallel( ...@@ -612,7 +612,7 @@ def test_linear_parallel(
default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
) -> None: ) -> None:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
...@@ -737,7 +737,7 @@ def test_column_parallel_packed( ...@@ -737,7 +737,7 @@ def test_column_parallel_packed(
default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
) -> None: ) -> None:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
...@@ -885,7 +885,7 @@ def test_merged_column_parallel_variable_slice( ...@@ -885,7 +885,7 @@ def test_merged_column_parallel_variable_slice(
default_vllm_config, dist_init, num_loras, num_slices, device, stage default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None: ) -> None:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
......
...@@ -37,7 +37,7 @@ EMBEDDING_MODULES = { ...@@ -37,7 +37,7 @@ EMBEDDING_MODULES = {
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] [f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
if current_platform.is_cuda_alike() if current_platform.is_cuda_alike()
else ["cpu"] else ["cpu"]
) )
......
...@@ -34,7 +34,7 @@ def do_sample( ...@@ -34,7 +34,7 @@ def do_sample(
def test_mixtral_lora(mixtral_lora_files, tp_size): def test_mixtral_lora(mixtral_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all""" """Original test, the LoRA model has the common target modules, not all"""
if ( if (
torch.cuda.device_count() < tp_size torch.accelerator.device_count() < tp_size
and tp_size > 1 and tp_size > 1
and current_platform.is_cuda_alike() and current_platform.is_cuda_alike()
): ):
......
...@@ -395,7 +395,7 @@ def test_kernels( ...@@ -395,7 +395,7 @@ def test_kernels(
Tests LoRA kernels. Tests LoRA kernels.
""" """
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
set_random_seed(seed) set_random_seed(seed)
if op_type == "shrink": if op_type == "shrink":
...@@ -448,7 +448,7 @@ def test_kernels_hidden_size( ...@@ -448,7 +448,7 @@ def test_kernels_hidden_size(
Tests SGMV and LoRA kernels. Tests SGMV and LoRA kernels.
""" """
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
set_random_seed(seed) set_random_seed(seed)
if op_type == "shrink": if op_type == "shrink":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment