Commit 53ec16a7 (unverified), authored by Kunshang Ji, committed by GitHub
Browse files

[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)


Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 2e693f48
......@@ -257,7 +257,7 @@ def test_eplb_fml(
intermediate_size: int,
column_major_scales: bool,
):
if torch.cuda.device_count() < world_size:
if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
num_local_experts = num_experts // world_size
......
......@@ -253,7 +253,7 @@ def test_eplb_fml(
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
if torch.cuda.device_count() < world_size:
if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
num_local_experts = num_experts // world_size
......
......@@ -38,7 +38,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables(
......@@ -84,7 +84,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
if world_size > torch.cuda.device_count():
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Enable SymmMemCommunicator
......
......@@ -54,7 +54,7 @@ def worker_fn_wrapper(fn):
update_environment_variables(env)
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
init_distributed_environment()
fn()
......@@ -73,7 +73,7 @@ def worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl():
distributed_run(worker_fn, 2)
......@@ -102,7 +102,7 @@ def multiple_allreduce_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
torch.accelerator.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce():
# this tests pynccl for multiple tp groups, in a standalone way
......@@ -130,7 +130,7 @@ def multiple_allreduce_with_vllm_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
torch.accelerator.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce_with_vllm():
# this tests pynccl for multiple tp groups, together with vllm
......@@ -185,7 +185,7 @@ def all_gather_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)
......@@ -220,7 +220,7 @@ def all_gatherv_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gatherv():
distributed_run(all_gatherv_worker_fn, 2)
......@@ -260,7 +260,7 @@ def reduce_scatter_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)
......@@ -298,14 +298,14 @@ def reduce_scatterv_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatterv():
distributed_run(reduce_scatterv_worker_fn, 2)
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)
......@@ -330,7 +330,7 @@ def send_recv_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
torch.accelerator.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_send_recv():
distributed_run(send_recv_worker_fn, 2)
......@@ -363,14 +363,14 @@ def multiple_send_recv_worker_fn():
@pytest.mark.skipif(
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
torch.accelerator.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_send_recv():
distributed_run(multiple_send_recv_worker_fn, 4)
@pytest.mark.skipif(
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
torch.accelerator.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_broadcast():
distributed_run(broadcast_worker_fn, 4)
......
......@@ -39,7 +39,7 @@ def graph_quickreduce(
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tp_group().device_group
......@@ -65,12 +65,10 @@ def graph_quickreduce(
for sz in test_sizes:
for dtype in [torch.float16, torch.bfloat16]:
with graph_capture(device=device) as graph_capture_context:
inp1 = torch.randint(
1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
inp2 = torch.randint(
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
device_idx = torch.accelerator.current_device_index()
inp1 = torch.randint(1, 23, (sz,), dtype=dtype, device=device_idx)
inp2 = torch.randint(-23, 1, (sz,), dtype=dtype, device=device_idx)
torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
......@@ -95,7 +93,7 @@ def eager_quickreduce(
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
......@@ -130,7 +128,7 @@ def test_custom_quick_allreduce(
quant_mode,
):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
......@@ -145,7 +143,7 @@ def qr_variable_input(rank, world_size):
has been observed with the gpt_oss model).
"""
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
qr_max_size = None # MB
_ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
ranks = []
......@@ -169,14 +167,13 @@ def qr_variable_input(rank, world_size):
s1 = 1024
while num < 50000: # 50000 is sufficient to identify issues.
dtype = torch.float16
device_idx = torch.accelerator.current_device_index()
if num % 2 == 0:
s2 = 1024
inp1 = torch.zeros(
(s1, s2), dtype=dtype, device=torch.cuda.current_device()
)
inp1 = torch.zeros((s1, s2), dtype=dtype, device=device_idx)
else:
s2 = 2048
inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device())
inp1 = torch.ones((s1, s2), dtype=dtype, device=device_idx)
result = torch.empty_like(inp1)
# FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
......@@ -198,7 +195,7 @@ def qr_variable_input(rank, world_size):
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
multiprocessing.set_start_method("spawn", force=True)
......
......@@ -39,7 +39,7 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables(
......@@ -105,7 +105,7 @@ def test_symm_mem_allreduce(
monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context("spawn").Queue()
mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)
......@@ -126,7 +126,7 @@ def test_symm_mem_allreduce(
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4
if world_size > torch.cuda.device_count():
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Verify that the DataParallel runs without error
engine_args = EngineArgs(
......
......@@ -66,7 +66,7 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank)
torch.accelerator.set_device_index(rank)
pg1 = StatelessProcessGroup.create(
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
)
......
......@@ -203,7 +203,7 @@ class TestEngineRegistry:
def test_nccl_receive_weights_without_init_raises():
"""Test that receive_weights raises if init_transfer_engine wasn't called."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="nccl")
......@@ -336,7 +336,7 @@ def inference_receive_tensor(
@pytest.mark.skipif(
torch.cuda.device_count() < 2,
torch.accelerator.device_count() < 2,
reason="Need at least 2 GPUs to run NCCL weight transfer test.",
)
def test_nccl_weight_transfer_between_processes():
......@@ -382,7 +382,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_valid_update_info(self):
"""Test creating valid IPCWeightTransferUpdateInfo."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Create a dummy tensor and IPC handle
......@@ -404,7 +404,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_dtype_names_raises(self):
"""Test that mismatched dtype_names length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
......@@ -422,7 +422,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_shapes_raises(self):
"""Test that mismatched shapes length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
......@@ -440,7 +440,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_ipc_handles_raises(self):
"""Test that mismatched ipc_handles length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
......@@ -458,7 +458,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_valid_update_info_from_pickled(self, monkeypatch):
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
......@@ -493,7 +493,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
......@@ -540,7 +540,7 @@ class TestIPCEngineParsing:
def test_parse_update_info_valid(self):
"""Test parsing valid update info dict."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
......@@ -572,7 +572,7 @@ class TestIPCEngineParsing:
def test_parse_update_info_pickled(self, monkeypatch):
"""Test parsing update info with pickled IPC handles (HTTP path)."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
......@@ -731,7 +731,7 @@ def inference_receive_ipc_tensor(
@pytest.mark.skipif(
torch.cuda.device_count() < 1,
torch.accelerator.device_count() < 1,
reason="Need at least 1 GPU to run IPC weight transfer test.",
)
@pytest.mark.parametrize("mode", ["ray", "http"])
......@@ -789,7 +789,7 @@ def test_ipc_weight_transfer_between_processes(mode: str):
def test_ipc_receive_weights_missing_gpu_uuid_raises():
"""Test that receive_weights raises if GPU UUID not found in IPC handles."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
......
......@@ -13,7 +13,7 @@ from ...utils import create_new_process_for_each_test
@pytest.mark.parametrize("backend", ["mp", "ray"])
@create_new_process_for_each_test()
def test_collective_rpc(tp_size, backend, monkeypatch):
if torch.cuda.device_count() < tp_size:
if torch.accelerator.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
if tp_size == 1 and backend == "ray":
pytest.skip("Skip duplicate test case")
......
......@@ -106,7 +106,7 @@ def mock_create_engine(config, parallel_config):
@create_new_process_for_each_test()
def test_get_world_size_tp1():
"""Test world_size is correctly configured for TP=1."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
llm = LLM(
......@@ -125,7 +125,7 @@ def test_get_world_size_tp1():
def test_init_weight_transfer_engine_calls_engine():
"""Test that init_weight_transfer_engine calls the engine's
init_transfer_engine method."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
......@@ -174,7 +174,7 @@ def test_init_weight_transfer_engine_calls_engine():
@create_new_process_for_each_test()
def test_update_weights_calls_engine():
"""Test that update_weights calls the engine's receive_weights method."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
......@@ -233,7 +233,7 @@ def test_update_weights_calls_engine():
@create_new_process_for_each_test()
def test_full_weight_transfer_flow():
"""Test the complete weight transfer flow: init -> update."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
......@@ -294,7 +294,7 @@ def test_full_weight_transfer_flow():
@create_new_process_for_each_test()
def test_weight_transfer_config_backend():
"""Test that WeightTransferConfig backend is properly configured."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Test with nccl backend
......
......@@ -36,7 +36,9 @@ BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
def ref_masked_attention(
......
......@@ -35,7 +35,9 @@ NUM_BLOCKS = [1024, 10000]
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
......@@ -69,7 +71,7 @@ def test_reshape_and_cache(
pytest.skip()
set_random_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
......@@ -192,7 +194,7 @@ def test_reshape_and_cache_flash(
) -> None:
set_random_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
......@@ -553,7 +555,7 @@ def test_concat_and_cache_mla(
) -> None:
set_random_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
......@@ -632,7 +634,7 @@ def test_concat_and_cache_ds_mla(
kv_cache_dtype = "fp8_ds_mla"
set_random_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
......@@ -744,7 +746,7 @@ def test_swap_blocks_mla(
) -> None:
set_random_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
entry_size = kv_lora_rank + qk_rope_head_dim
......
......@@ -69,7 +69,7 @@ def test_cutlass_mla_decode(
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.manual_seed(42)
random.seed(42)
......
......@@ -57,7 +57,7 @@ def test_flash_mla(
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.manual_seed(0)
random.seed(0)
......
......@@ -21,7 +21,9 @@ NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
......@@ -135,7 +137,7 @@ def test_contexted_kv_attention(
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
......@@ -356,7 +358,7 @@ def test_contexted_kv_attention_alibi(
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
......
......@@ -26,7 +26,9 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
@pytest.mark.parametrize(
......
......@@ -33,7 +33,9 @@ SCALE_UBS = [True, False]
GROUP_SIZES = [None, [1, 64], [1, 128]]
TMA_ALIGNMENTS = [0, 4]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
EPS = 1e-6
......@@ -182,7 +184,7 @@ def test_rms_norm(
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
if group_size is not None and hidden_size % group_size[1] != 0:
# skip
......
......@@ -14,7 +14,9 @@ NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [8, 768, 769, 5120, 5125, 8192] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
......
......@@ -19,7 +19,9 @@ NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
USE_KEY = [True, False]
......
......@@ -28,7 +28,8 @@ from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("block_size", [16, 64, 256])
@pytest.mark.parametrize("seed", [0])
@pytest.mark.parametrize(
"device", [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
"device",
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)],
)
@torch.inference_mode()
def test_concat_and_cache_mla_rope_fused(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment