Unverified Commit 53ec16a7 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
......@@ -203,7 +203,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref)
torch.accelerator.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
@pytest.mark.skipif(torch.accelerator.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
try:
model_ref = "EleutherAI/pythia-1.4b"
......@@ -231,7 +231,7 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
) in combined_output
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
@pytest.mark.skipif(torch.accelerator.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
vllm_runner, tmp_path
):
......
......@@ -11,7 +11,7 @@ from vllm.model_executor.models.utils import get_draft_quant_config
from vllm.platforms import current_platform
DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
if current_platform.is_cuda_alike()
else ["cpu"]
)
......@@ -61,7 +61,7 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
from vllm.model_executor.layers.linear import ReplicatedLinear
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
......
......@@ -102,7 +102,7 @@ def run_dp_sharded_vision_model_vs_direct(
set_random_seed(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
update_environment_variables(
......@@ -288,7 +288,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct(
# Set random seed for reproducibility
set_random_seed(0)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
update_environment_variables(
......@@ -365,7 +365,7 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
update_environment_variables(
......@@ -414,7 +414,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
# Set up distributed environment
set_random_seed(123)
device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
update_environment_variables(
......
......@@ -210,10 +210,9 @@ WIKITEXT_ACCURACY_CONFIGS = [
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
if torch.cuda.device_count() < tp_size:
pytest.skip(
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
)
device_count = torch.accelerator.device_count()
if device_count < tp_size:
pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")
task = "wikitext"
rtol = 0.1
......@@ -246,10 +245,9 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
)
device_count = torch.accelerator.device_count()
if device_count < 8:
pytest.skip(f"This test requires >=8 gpus, got only {device_count}")
task = "gsm8k"
rtol = 0.03
......
......@@ -32,7 +32,7 @@ MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm."""
available_gpus = torch.cuda.device_count()
available_gpus = torch.accelerator.device_count()
if available_gpus < tp_size:
pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
......
......@@ -148,7 +148,7 @@ def test_shared_storage_connector_hashes(tmp_path, attn_backend):
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
# it will call torch.accelerator.device_count()
from transformers import AutoProcessor
# Create processor to handle the chat prompt
......
......@@ -1570,7 +1570,7 @@ def test_register_kv_caches(
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
device=torch.accelerator.current_device_index(),
kernel_block_sizes=[block_size],
)
)
......
......@@ -141,7 +141,7 @@ def get_attention_backend_params() -> list[str]:
def get_tp_size_params() -> list[pytest.param]:
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
num_gpus = torch.accelerator.device_count() if torch.cuda.is_available() else 1
return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus]
......
......@@ -117,7 +117,8 @@ def worker_process(
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism"
torch.accelerator.device_count() < 2,
reason="Need at least 2 GPUs for tensor parallelism",
)
def test_init_distributed_is_called_before_memory_snapshot():
"""Test that distributed env is setup before memory snapshot.
......
......@@ -8,8 +8,8 @@ import regex as re
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.(empty_cache|synchronize|device\()\b",
r"\bwith\btorch\.cuda\.device\b",
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|set_device|device\()\b",
r"\bwith\storch\.cuda\.device\b",
]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
......@@ -25,7 +25,9 @@ def scan_file(path: str) -> int:
print(
f"{path}:{line_num}: "
"\033[91merror:\033[0m " # red color
"Found torch.cuda API call"
"Found torch.cuda API call. Please refer RFC "
"https://github.com/vllm-project/vllm/issues/30679, use "
"torch.accelerator API instead."
)
return 1
return 0
......
......@@ -491,7 +491,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.initialize(
world_size=self.world_size,
rank=self.rank,
gpus_per_node=torch.cuda.device_count,
gpus_per_node=torch.accelerator.device_count,
)
return self.initialized
......
......@@ -151,7 +151,7 @@ class nccl_symm_mem_context:
self.pynccl_comm = pynccl_comm
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.device = torch.accelerator.current_device_index()
def __enter__(self):
if self.disabled:
......
......@@ -50,7 +50,7 @@ class SymmMemCommunicator:
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
......
......@@ -33,7 +33,7 @@ def start_async_worker(
def thread_target() -> None:
assert device_index is not None
torch.cuda.set_device(device_index)
torch.accelerator.set_device_index(device_index)
cuda_stream = torch.cuda.Stream(device=device_index)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
......
......@@ -314,7 +314,7 @@ class EplbState:
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
self.cuda_device_index = torch.cuda.current_device()
self.cuda_device_index = torch.accelerator.current_device_index()
@staticmethod
def build_initial_global_physical_to_logical_map(
......
......@@ -483,9 +483,9 @@ def _init_lmcache_engine(
)
# Change current device.
num_gpus = torch.cuda.device_count()
num_gpus = torch.accelerator.device_count()
local_rank = parallel_config.rank % num_gpus
torch.cuda.set_device(local_rank)
torch.accelerator.set_device_index(local_rank)
device = torch.device(f"cuda:{local_rank}")
metadata = LMCacheEngineMetadata(
model_config.model,
......
......@@ -169,7 +169,7 @@ class IPCWeightTransferEngine(
update_info.shapes,
update_info.ipc_handles,
):
device_index = torch.cuda.current_device()
device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
......@@ -242,7 +242,7 @@ class IPCWeightTransferEngine(
args = trainer_args
# Get physical GPU UUID
device_index = torch.cuda.current_device()
device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid)
......
......@@ -140,13 +140,14 @@ class NCCLWeightTransferEngine(
worker_rank = dp_rank * world_size_per_dp + rank_within_dp
rank = worker_rank + init_info.rank_offset
# Create stateless process group
device = torch.accelerator.current_device_index()
self.model_update_group = (
NCCLWeightTransferEngine._stateless_init_process_group(
init_info.master_address,
init_info.master_port,
rank,
init_info.world_size,
torch.cuda.current_device(),
device=device,
)
)
......@@ -275,7 +276,7 @@ class NCCLWeightTransferEngine(
Initialize NCCL process group for trainer-side weight transfer.
The trainer is always rank 0 in the process group. Uses the current
CUDA device (torch.cuda.current_device()).
CUDA device (torch.accelerator.current_device_index()).
Args:
init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys:
......@@ -309,8 +310,13 @@ class NCCLWeightTransferEngine(
world_size = init_info.world_size
# Trainer is always rank 0
device = torch.accelerator.current_device_index()
return NCCLWeightTransferEngine._stateless_init_process_group(
master_address, master_port, 0, world_size, torch.cuda.current_device()
master_address,
master_port,
0,
world_size,
device,
)
@staticmethod
......
......@@ -190,7 +190,7 @@ class StaticSinkAttention(Attention, CustomOp):
sink_kv_slot_mapping = torch.arange(
self.block_size,
self.sink_len + self.block_size,
device=torch.cuda.current_device(),
device=torch.accelerator.current_device_index(),
dtype=torch.long,
)
triton_reshape_and_cache_flash_diffkv(
......
......@@ -295,14 +295,17 @@ class DefaultMoERunner(MoERunner):
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
device = torch.accelerator.current_device_index()
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
states_shape,
dtype=moe.in_dtype,
device=device,
)
self.batched_router_logits = torch.zeros(
logits_shape,
dtype=moe.router_logits_dtype,
device=torch.cuda.current_device(),
device=device,
)
def must_reduce_shared_expert_outputs(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment