Unverified Commit 53ec16a7 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
...@@ -203,7 +203,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref) ...@@ -203,7 +203,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref)
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @pytest.mark.skipif(torch.accelerator.device_count() < 2, reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
try: try:
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
...@@ -231,7 +231,7 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): ...@@ -231,7 +231,7 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd):
) in combined_output ) in combined_output
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") @pytest.mark.skipif(torch.accelerator.device_count() < 2, reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
vllm_runner, tmp_path vllm_runner, tmp_path
): ):
......
...@@ -11,7 +11,7 @@ from vllm.model_executor.models.utils import get_draft_quant_config ...@@ -11,7 +11,7 @@ from vllm.model_executor.models.utils import get_draft_quant_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] [f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
if current_platform.is_cuda_alike() if current_platform.is_cuda_alike()
else ["cpu"] else ["cpu"]
) )
...@@ -61,7 +61,7 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> ...@@ -61,7 +61,7 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
......
...@@ -102,7 +102,7 @@ def run_dp_sharded_vision_model_vs_direct( ...@@ -102,7 +102,7 @@ def run_dp_sharded_vision_model_vs_direct(
set_random_seed(0) set_random_seed(0)
device = f"{current_platform.device_name}:{local_rank}" device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
update_environment_variables( update_environment_variables(
...@@ -288,7 +288,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct( ...@@ -288,7 +288,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct(
# Set random seed for reproducibility # Set random seed for reproducibility
set_random_seed(0) set_random_seed(0)
device = f"{current_platform.device_name}:{local_rank}" device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
update_environment_variables( update_environment_variables(
...@@ -365,7 +365,7 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker( ...@@ -365,7 +365,7 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
"""Test run_dp_sharded_mrope_vision_model with empty input.""" """Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment # Set up distributed environment
device = f"{current_platform.device_name}:{local_rank}" device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
update_environment_variables( update_environment_variables(
...@@ -414,7 +414,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( ...@@ -414,7 +414,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
# Set up distributed environment # Set up distributed environment
set_random_seed(123) set_random_seed(123)
device = f"{current_platform.device_name}:{local_rank}" device = f"{current_platform.device_name}:{local_rank}"
current_platform.set_device(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
update_environment_variables( update_environment_variables(
......
...@@ -210,10 +210,9 @@ WIKITEXT_ACCURACY_CONFIGS = [ ...@@ -210,10 +210,9 @@ WIKITEXT_ACCURACY_CONFIGS = [
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS) @pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
if torch.cuda.device_count() < tp_size: device_count = torch.accelerator.device_count()
pytest.skip( if device_count < tp_size:
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}" pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")
)
task = "wikitext" task = "wikitext"
rtol = 0.1 rtol = 0.1
...@@ -246,10 +245,9 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): ...@@ -246,10 +245,9 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
reason="Read access to huggingface.co/amd is required for this test.", reason="Read access to huggingface.co/amd is required for this test.",
) )
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8: device_count = torch.accelerator.device_count()
pytest.skip( if device_count < 8:
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}" pytest.skip(f"This test requires >=8 gpus, got only {device_count}")
)
task = "gsm8k" task = "gsm8k"
rtol = 0.03 rtol = 0.03
......
...@@ -32,7 +32,7 @@ MTP_SIMILARITY_RATE = 0.8 ...@@ -32,7 +32,7 @@ MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int): def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm.""" """Skip test if available GPUs < tp_size on ROCm."""
available_gpus = torch.cuda.device_count() available_gpus = torch.accelerator.device_count()
if available_gpus < tp_size: if available_gpus < tp_size:
pytest.skip( pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available" f"Test requires {tp_size} GPUs, but only {available_gpus} available"
......
...@@ -148,7 +148,7 @@ def test_shared_storage_connector_hashes(tmp_path, attn_backend): ...@@ -148,7 +148,7 @@ def test_shared_storage_connector_hashes(tmp_path, attn_backend):
) )
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.accelerator.device_count()
from transformers import AutoProcessor from transformers import AutoProcessor
# Create processor to handle the chat prompt # Create processor to handle the chat prompt
......
...@@ -1570,7 +1570,7 @@ def test_register_kv_caches( ...@@ -1570,7 +1570,7 @@ def test_register_kv_caches(
] ]
], ],
cache_dtype=torch.bfloat16, cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(), device=torch.accelerator.current_device_index(),
kernel_block_sizes=[block_size], kernel_block_sizes=[block_size],
) )
) )
......
...@@ -141,7 +141,7 @@ def get_attention_backend_params() -> list[str]: ...@@ -141,7 +141,7 @@ def get_attention_backend_params() -> list[str]:
def get_tp_size_params() -> list[pytest.param]: def get_tp_size_params() -> list[pytest.param]:
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 num_gpus = torch.accelerator.device_count() if torch.cuda.is_available() else 1
return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus] return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus]
......
...@@ -117,7 +117,8 @@ def worker_process( ...@@ -117,7 +117,8 @@ def worker_process(
@pytest.mark.skipif( @pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism" torch.accelerator.device_count() < 2,
reason="Need at least 2 GPUs for tensor parallelism",
) )
def test_init_distributed_is_called_before_memory_snapshot(): def test_init_distributed_is_called_before_memory_snapshot():
"""Test that distributed env is setup before memory snapshot. """Test that distributed env is setup before memory snapshot.
......
...@@ -8,8 +8,8 @@ import regex as re ...@@ -8,8 +8,8 @@ import regex as re
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx` # Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [ _TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.(empty_cache|synchronize|device\()\b", r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|set_device|device\()\b",
r"\bwith\btorch\.cuda\.device\b", r"\bwith\storch\.cuda\.device\b",
] ]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"} ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
...@@ -25,7 +25,9 @@ def scan_file(path: str) -> int: ...@@ -25,7 +25,9 @@ def scan_file(path: str) -> int:
print( print(
f"{path}:{line_num}: " f"{path}:{line_num}: "
"\033[91merror:\033[0m " # red color "\033[91merror:\033[0m " # red color
"Found torch.cuda API call" "Found torch.cuda API call. Please refer RFC "
"https://github.com/vllm-project/vllm/issues/30679, use "
"torch.accelerator API instead."
) )
return 1 return 1
return 0 return 0
......
...@@ -491,7 +491,7 @@ class FlashInferAllToAllManager(All2AllManagerBase): ...@@ -491,7 +491,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.initialize( self.initialize(
world_size=self.world_size, world_size=self.world_size,
rank=self.rank, rank=self.rank,
gpus_per_node=torch.cuda.device_count, gpus_per_node=torch.accelerator.device_count,
) )
return self.initialized return self.initialized
......
...@@ -151,7 +151,7 @@ class nccl_symm_mem_context: ...@@ -151,7 +151,7 @@ class nccl_symm_mem_context:
self.pynccl_comm = pynccl_comm self.pynccl_comm = pynccl_comm
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing() self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device() self.device = torch.accelerator.current_device_index()
def __enter__(self): def __enter__(self):
if self.disabled: if self.disabled:
......
...@@ -50,7 +50,7 @@ class SymmMemCommunicator: ...@@ -50,7 +50,7 @@ class SymmMemCommunicator:
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
self.dtype = torch.bfloat16 self.dtype = torch.bfloat16
self.device = device self.device = device
self.group = group self.group = group
......
...@@ -33,7 +33,7 @@ def start_async_worker( ...@@ -33,7 +33,7 @@ def start_async_worker(
def thread_target() -> None: def thread_target() -> None:
assert device_index is not None assert device_index is not None
torch.cuda.set_device(device_index) torch.accelerator.set_device_index(device_index)
cuda_stream = torch.cuda.Stream(device=device_index) cuda_stream = torch.cuda.Stream(device=device_index)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
......
...@@ -314,7 +314,7 @@ class EplbState: ...@@ -314,7 +314,7 @@ class EplbState:
if self.device.type == "cuda": if self.device.type == "cuda":
self.cuda_device_index = self.device.index self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available(): if self.cuda_device_index is None and torch.cuda.is_available():
self.cuda_device_index = torch.cuda.current_device() self.cuda_device_index = torch.accelerator.current_device_index()
@staticmethod @staticmethod
def build_initial_global_physical_to_logical_map( def build_initial_global_physical_to_logical_map(
......
...@@ -483,9 +483,9 @@ def _init_lmcache_engine( ...@@ -483,9 +483,9 @@ def _init_lmcache_engine(
) )
# Change current device. # Change current device.
num_gpus = torch.cuda.device_count() num_gpus = torch.accelerator.device_count()
local_rank = parallel_config.rank % num_gpus local_rank = parallel_config.rank % num_gpus
torch.cuda.set_device(local_rank) torch.accelerator.set_device_index(local_rank)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
metadata = LMCacheEngineMetadata( metadata = LMCacheEngineMetadata(
model_config.model, model_config.model,
......
...@@ -169,7 +169,7 @@ class IPCWeightTransferEngine( ...@@ -169,7 +169,7 @@ class IPCWeightTransferEngine(
update_info.shapes, update_info.shapes,
update_info.ipc_handles, update_info.ipc_handles,
): ):
device_index = torch.cuda.current_device() device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index) props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid) physical_gpu_id = str(props.uuid)
...@@ -242,7 +242,7 @@ class IPCWeightTransferEngine( ...@@ -242,7 +242,7 @@ class IPCWeightTransferEngine(
args = trainer_args args = trainer_args
# Get physical GPU UUID # Get physical GPU UUID
device_index = torch.cuda.current_device() device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index) props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid) gpu_uuid = str(props.uuid)
......
...@@ -140,13 +140,14 @@ class NCCLWeightTransferEngine( ...@@ -140,13 +140,14 @@ class NCCLWeightTransferEngine(
worker_rank = dp_rank * world_size_per_dp + rank_within_dp worker_rank = dp_rank * world_size_per_dp + rank_within_dp
rank = worker_rank + init_info.rank_offset rank = worker_rank + init_info.rank_offset
# Create stateless process group # Create stateless process group
device = torch.accelerator.current_device_index()
self.model_update_group = ( self.model_update_group = (
NCCLWeightTransferEngine._stateless_init_process_group( NCCLWeightTransferEngine._stateless_init_process_group(
init_info.master_address, init_info.master_address,
init_info.master_port, init_info.master_port,
rank, rank,
init_info.world_size, init_info.world_size,
torch.cuda.current_device(), device=device,
) )
) )
...@@ -275,7 +276,7 @@ class NCCLWeightTransferEngine( ...@@ -275,7 +276,7 @@ class NCCLWeightTransferEngine(
Initialize NCCL process group for trainer-side weight transfer. Initialize NCCL process group for trainer-side weight transfer.
The trainer is always rank 0 in the process group. Uses the current The trainer is always rank 0 in the process group. Uses the current
CUDA device (torch.cuda.current_device()). CUDA device (torch.accelerator.current_device_index()).
Args: Args:
init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys: init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys:
...@@ -309,8 +310,13 @@ class NCCLWeightTransferEngine( ...@@ -309,8 +310,13 @@ class NCCLWeightTransferEngine(
world_size = init_info.world_size world_size = init_info.world_size
# Trainer is always rank 0 # Trainer is always rank 0
device = torch.accelerator.current_device_index()
return NCCLWeightTransferEngine._stateless_init_process_group( return NCCLWeightTransferEngine._stateless_init_process_group(
master_address, master_port, 0, world_size, torch.cuda.current_device() master_address,
master_port,
0,
world_size,
device,
) )
@staticmethod @staticmethod
......
...@@ -190,7 +190,7 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -190,7 +190,7 @@ class StaticSinkAttention(Attention, CustomOp):
sink_kv_slot_mapping = torch.arange( sink_kv_slot_mapping = torch.arange(
self.block_size, self.block_size,
self.sink_len + self.block_size, self.sink_len + self.block_size,
device=torch.cuda.current_device(), device=torch.accelerator.current_device_index(),
dtype=torch.long, dtype=torch.long,
) )
triton_reshape_and_cache_flash_diffkv( triton_reshape_and_cache_flash_diffkv(
......
...@@ -295,14 +295,17 @@ class DefaultMoERunner(MoERunner): ...@@ -295,14 +295,17 @@ class DefaultMoERunner(MoERunner):
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim) states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts) logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
device = torch.accelerator.current_device_index()
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() states_shape,
dtype=moe.in_dtype,
device=device,
) )
self.batched_router_logits = torch.zeros( self.batched_router_logits = torch.zeros(
logits_shape, logits_shape,
dtype=moe.router_logits_dtype, dtype=moe.router_logits_dtype,
device=torch.cuda.current_device(), device=device,
) )
def must_reduce_shared_expert_outputs(self) -> bool: def must_reduce_shared_expert_outputs(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment