Unverified commit 3abf8584, authored by wliao2, committed by GitHub
Browse files

[Test] Refactor hard coded device string in test files under...


[Test] Refactor hard coded device string in test files under compile/quantization/models/model_executor folders (#38901)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
parent f4b42df0
...@@ -13,6 +13,8 @@ from vllm.utils.mem_constants import GiB_bytes ...@@ -13,6 +13,8 @@ from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test, requires_fp8 from ..utils import create_new_process_for_each_test, requires_fp8
DEVICE_TYPE = current_platform.device_type
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn") @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_python_error(): def test_python_error():
...@@ -26,13 +28,13 @@ def test_python_error(): ...@@ -26,13 +28,13 @@ def test_python_error():
tensors = [] tensors = []
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# allocate 70% of the total memory # allocate 70% of the total memory
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") x = torch.empty(alloc_bytes, dtype=torch.uint8, device=DEVICE_TYPE)
tensors.append(x) tensors.append(x)
# release the memory # release the memory
allocator.sleep() allocator.sleep()
# allocate more memory than the total memory # allocate more memory than the total memory
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") y = torch.empty(alloc_bytes, dtype=torch.uint8, device=DEVICE_TYPE)
tensors.append(y) tensors.append(y)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# when the allocator is woken up, it should raise an error # when the allocator is woken up, it should raise an error
...@@ -44,17 +46,17 @@ def test_python_error(): ...@@ -44,17 +46,17 @@ def test_python_error():
def test_basic_cumem(): def test_basic_cumem():
# some tensors from default memory pool # some tensors from default memory pool
shape = (1024, 1024) shape = (1024, 1024)
x = torch.empty(shape, device="cuda") x = torch.empty(shape, device=DEVICE_TYPE)
x.zero_() x.zero_()
# some tensors from custom memory pool # some tensors from custom memory pool
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# custom memory pool # custom memory pool
y = torch.empty(shape, device="cuda") y = torch.empty(shape, device=DEVICE_TYPE)
y.zero_() y.zero_()
y += 1 y += 1
z = torch.empty(shape, device="cuda") z = torch.empty(shape, device=DEVICE_TYPE)
z.zero_() z.zero_()
z += 2 z += 2
...@@ -77,16 +79,16 @@ def test_basic_cumem(): ...@@ -77,16 +79,16 @@ def test_basic_cumem():
def test_cumem_with_cudagraph(): def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
weight = torch.eye(1024, device="cuda") weight = torch.eye(1024, device=DEVICE_TYPE)
with allocator.use_memory_pool(tag="discard"): with allocator.use_memory_pool(tag="discard"):
cache = torch.empty(1024, 1024, device="cuda") cache = torch.empty(1024, 1024, device=DEVICE_TYPE)
def model(x): def model(x):
out = x @ weight out = x @ weight
cache[: out.size(0)].copy_(out) cache[: out.size(0)].copy_(out)
return out + 1 return out + 1
x = torch.empty(128, 1024, device="cuda") x = torch.empty(128, 1024, device=DEVICE_TYPE)
# warmup # warmup
model(x) model(x)
......
...@@ -31,6 +31,7 @@ from vllm.platforms import current_platform ...@@ -31,6 +31,7 @@ from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
DEVICE_TYPE = current_platform.device_type
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
prompts = [ prompts = [
...@@ -299,7 +300,7 @@ def async_tp_pass_on_test_model( ...@@ -299,7 +300,7 @@ def async_tp_pass_on_test_model(
): ):
set_random_seed(0) set_random_seed(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
torch.accelerator.set_device_index(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -324,7 +325,7 @@ def async_tp_pass_on_test_model( ...@@ -324,7 +325,7 @@ def async_tp_pass_on_test_model(
fuse_gemm_comms=True, fuse_gemm_comms=True,
), ),
) )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device(DEVICE_TYPE))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
......
...@@ -37,6 +37,8 @@ from vllm.platforms import current_platform ...@@ -37,6 +37,8 @@ from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
DEVICE_TYPE = current_platform.device_type
class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__( def __init__(
...@@ -268,7 +270,7 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -268,7 +270,7 @@ def all_reduce_fusion_pass_on_test_model(
): ):
set_random_seed(0) set_random_seed(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
torch.accelerator.set_device_index(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -300,7 +302,7 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -300,7 +302,7 @@ def all_reduce_fusion_pass_on_test_model(
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
fuse_allreduce_rms=True, eliminate_noops=True fuse_allreduce_rms=True, eliminate_noops=True
) )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device(DEVICE_TYPE))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
......
...@@ -35,6 +35,8 @@ from vllm.platforms import current_platform ...@@ -35,6 +35,8 @@ from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
DEVICE_TYPE = current_platform.device_type
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -228,7 +230,7 @@ def sequence_parallelism_pass_on_test_model( ...@@ -228,7 +230,7 @@ def sequence_parallelism_pass_on_test_model(
): ):
set_random_seed(0) set_random_seed(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
torch.accelerator.set_device_index(device) torch.accelerator.set_device_index(device)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -258,7 +260,7 @@ def sequence_parallelism_pass_on_test_model( ...@@ -258,7 +260,7 @@ def sequence_parallelism_pass_on_test_model(
eliminate_noops=True, eliminate_noops=True,
), ),
) # NoOp needed for fusion ) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda")) device_config = DeviceConfig(device=torch.device(DEVICE_TYPE))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
......
...@@ -41,6 +41,7 @@ from vllm.v1.attention.backend import AttentionMetadata ...@@ -41,6 +41,7 @@ from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode
DEVICE_TYPE = current_platform.device_type
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
...@@ -300,7 +301,7 @@ def test_attention_quant_pattern( ...@@ -300,7 +301,7 @@ def test_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(42) torch.manual_seed(42)
......
...@@ -45,6 +45,7 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec ...@@ -45,6 +45,7 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
DEVICE_TYPE = current_platform.device_type
class MLAAttentionQuantPatternModel(torch.nn.Module): class MLAAttentionQuantPatternModel(torch.nn.Module):
...@@ -356,7 +357,7 @@ def test_mla_attention_quant_pattern( ...@@ -356,7 +357,7 @@ def test_mla_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(42) torch.manual_seed(42)
......
...@@ -8,6 +8,9 @@ import vllm ...@@ -8,6 +8,9 @@ import vllm
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
...@@ -17,7 +20,7 @@ from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConf ...@@ -17,7 +20,7 @@ from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConf
) )
@pytest.mark.parametrize("hidden_size", [64, 4096]) @pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
...@@ -88,7 +91,7 @@ def test_non_noop_slice_preserved(): ...@@ -88,7 +91,7 @@ def test_non_noop_slice_preserved():
Regression test for a bug where end=-1 was treated like an inferred Regression test for a bug where end=-1 was treated like an inferred
dimension (reshape semantics) leading to incorrect elimination. dimension (reshape semantics) leading to incorrect elimination.
""" """
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
x = torch.randn(16, 16) x = torch.randn(16, 16)
class SliceModel(torch.nn.Module): class SliceModel(torch.nn.Module):
......
...@@ -13,6 +13,9 @@ from vllm.compilation.passes.utility.scatter_split_replace import ( ...@@ -13,6 +13,9 @@ from vllm.compilation.passes.utility.scatter_split_replace import (
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import CompilationConfig, CompilationMode, VllmConfig from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
class ScatterSplitReplacementModel(nn.Module): class ScatterSplitReplacementModel(nn.Module):
...@@ -61,7 +64,7 @@ class ScatterSplitReplacementModel(nn.Module): ...@@ -61,7 +64,7 @@ class ScatterSplitReplacementModel(nn.Module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scatter_split_replace(dtype): def test_scatter_split_replace(dtype):
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -8,6 +8,9 @@ import vllm ...@@ -8,6 +8,9 @@ import vllm
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
class SplitCoalescingModel(torch.nn.Module): class SplitCoalescingModel(torch.nn.Module):
...@@ -28,7 +31,7 @@ class SplitCoalescingModel(torch.nn.Module): ...@@ -28,7 +31,7 @@ class SplitCoalescingModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_split_coalescing(dtype): def test_split_coalescing(dtype):
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -31,6 +31,8 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher ...@@ -31,6 +31,8 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401 from . import silly_attention # noqa: F401
DEVICE_TYPE = current_platform.device_type
def test_version(): def test_version():
# Test the version comparison logic using the private function # Test the version comparison logic using the private function
...@@ -456,7 +458,7 @@ def test_cached_compilation_config(default_vllm_config): ...@@ -456,7 +458,7 @@ def test_cached_compilation_config(default_vllm_config):
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
dtype = torch.bfloat16 dtype = torch.bfloat16
device = torch.device("cuda:0") device = torch.device(f"{DEVICE_TYPE}:0")
batch_size, num_qo_heads, head_size = 8, 16, 128 batch_size, num_qo_heads, head_size = 8, 16, 128
# access and cache default compilation config # access and cache default compilation config
...@@ -478,7 +480,7 @@ def test_cached_compilation_config(default_vllm_config): ...@@ -478,7 +480,7 @@ def test_cached_compilation_config(default_vllm_config):
query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
query_quant = torch.compile(query_quant) query_quant = torch.compile(query_quant)
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
query = torch.randn( query = torch.randn(
batch_size, num_qo_heads * head_size, dtype=dtype, device=device batch_size, num_qo_heads * head_size, dtype=dtype, device=device
) )
......
...@@ -15,10 +15,13 @@ from vllm.compilation.backends import ( ...@@ -15,10 +15,13 @@ from vllm.compilation.backends import (
split_graph, split_graph,
) )
from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.platforms import current_platform
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401 from . import silly_attention # noqa: F401
DEVICE_TYPE = current_platform.device_type
def test_getitem_moved_to_producer_subgraph(): def test_getitem_moved_to_producer_subgraph():
""" """
...@@ -151,7 +154,7 @@ def test_consecutive_ops_in_split(): ...@@ -151,7 +154,7 @@ def test_consecutive_ops_in_split():
final_result = torch.sigmoid(attn_inout) final_result = torch.sigmoid(attn_inout)
return final_result return final_result
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
# Create the traced FX graph for the model # Create the traced FX graph for the model
x = torch.randn(8, 4) x = torch.randn(8, 4)
...@@ -329,7 +332,7 @@ def test_builtin_empty_only_partition_is_merged(): ...@@ -329,7 +332,7 @@ def test_builtin_empty_only_partition_is_merged():
"Expected two builtin empty_like nodes in merged non-splitting subgraph" "Expected two builtin empty_like nodes in merged non-splitting subgraph"
) )
x = torch.randn(2, 3, device="cuda") x = torch.randn(2, 3, device=DEVICE_TYPE)
output_original = gm(x) output_original = gm(x)
output_split = split_gm(x) output_split = split_gm(x)
assert torch.allclose(output_original, output_split), "Output mismatch after split" assert torch.allclose(output_original, output_split), "Output mismatch after split"
......
...@@ -16,6 +16,8 @@ from vllm.config.compilation import CompilationMode, CUDAGraphMode ...@@ -16,6 +16,8 @@ from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@support_torch_compile @support_torch_compile
class RotaryEmbeddingCompileModule(torch.nn.Module): class RotaryEmbeddingCompileModule(torch.nn.Module):
...@@ -45,7 +47,7 @@ def test_rotary_embedding_torch_compile_with_custom_op(monkeypatch): ...@@ -45,7 +47,7 @@ def test_rotary_embedding_torch_compile_with_custom_op(monkeypatch):
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1") monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1")
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0") monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0")
device = "cuda" device = DEVICE_TYPE
positions = torch.arange(16, device=device) positions = torch.arange(16, device=device)
query = torch.randn(16, 32, device=device, dtype=torch.bfloat16) query = torch.randn(16, 32, device=device, dtype=torch.bfloat16)
key = torch.randn(16, 32, device=device, dtype=torch.bfloat16) key = torch.randn(16, 32, device=device, dtype=torch.bfloat16)
......
...@@ -17,8 +17,10 @@ from vllm.config.compilation import ( ...@@ -17,8 +17,10 @@ from vllm.config.compilation import (
) )
from vllm.config.scheduler import SchedulerConfig from vllm.config.scheduler import SchedulerConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
MLP_SIZE = 64 MLP_SIZE = 64
DEVICE_TYPE = current_platform.device_type
@support_torch_compile @support_torch_compile
...@@ -71,7 +73,7 @@ class TraceStructuredCapture: ...@@ -71,7 +73,7 @@ class TraceStructuredCapture:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache): def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
"""Test that all expected vLLM artifacts are logged during compilation.""" """Test that all expected vLLM artifacts are logged during compilation."""
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
capture = TraceStructuredCapture() capture = TraceStructuredCapture()
......
...@@ -10,9 +10,10 @@ from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig ...@@ -10,9 +10,10 @@ from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig
from vllm.model_executor.models.utils import get_draft_quant_config from vllm.model_executor.models.utils import get_draft_quant_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
DEVICES = ( DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)] [f"{DEVICE_TYPE}:{i}" for i in range(min(torch.accelerator.device_count(), 2))]
if current_platform.is_cuda_alike() if not current_platform.is_cpu()
else ["cpu"] else ["cpu"]
) )
......
...@@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download ...@@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets from ....conftest import ImageTestAssets
...@@ -15,6 +16,8 @@ from ....conftest import ImageTestAssets ...@@ -15,6 +16,8 @@ from ....conftest import ImageTestAssets
# dynamic_module and trust_remote_code for hf_runner # dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
DEVICE_TYPE = current_platform.device_type
@torch.inference_mode() @torch.inference_mode()
def run_intern_vit_test( def run_intern_vit_test(
...@@ -39,9 +42,9 @@ def run_intern_vit_test( ...@@ -39,9 +42,9 @@ def run_intern_vit_test(
hf_model = AutoModel.from_pretrained( hf_model = AutoModel.from_pretrained(
model, dtype=torch_dtype, trust_remote_code=True model, dtype=torch_dtype, trust_remote_code=True
).to("cuda") ).to(DEVICE_TYPE)
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state hf_model(pixel_value.to(DEVICE_TYPE)).last_hidden_state
for pixel_value in pixel_values for pixel_value in pixel_values
] ]
...@@ -53,9 +56,10 @@ def run_intern_vit_test( ...@@ -53,9 +56,10 @@ def run_intern_vit_test(
del hf_model del hf_model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
vllm_model = vllm_model.to("cuda", torch_dtype) vllm_model = vllm_model.to(DEVICE_TYPE, torch_dtype)
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values vllm_model(pixel_values=pixel_value.to(DEVICE_TYPE))
for pixel_value in pixel_values
] ]
del vllm_model del vllm_model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
......
...@@ -8,6 +8,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor ...@@ -8,6 +8,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.radio import RadioModel
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
...@@ -17,6 +18,8 @@ from ....conftest import ImageTestAssets ...@@ -17,6 +18,8 @@ from ....conftest import ImageTestAssets
# dynamic_module and trust_remote_code for hf_runner # dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
DEVICE_TYPE = current_platform.device_type
@torch.inference_mode() @torch.inference_mode()
def run_radio_test( def run_radio_test(
...@@ -51,7 +54,7 @@ def run_radio_test( ...@@ -51,7 +54,7 @@ def run_radio_test(
config=hf_config, config=hf_config,
dtype=torch_dtype, dtype=torch_dtype,
trust_remote_code=True, trust_remote_code=True,
).to("cuda") ).to(DEVICE_TYPE)
hf_model.eval() hf_model.eval()
# A HF model has image normalization as a part of model's forward # A HF model has image normalization as a part of model's forward
...@@ -62,7 +65,7 @@ def run_radio_test( ...@@ -62,7 +65,7 @@ def run_radio_test(
hf_model.make_preprocessor_external() hf_model.make_preprocessor_external()
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")) for pixel_value in pixel_values hf_model(pixel_value.to(DEVICE_TYPE)) for pixel_value in pixel_values
] ]
vllm_config = RadioConfig( vllm_config = RadioConfig(
...@@ -71,10 +74,11 @@ def run_radio_test( ...@@ -71,10 +74,11 @@ def run_radio_test(
) )
vllm_model = RadioModel(vllm_config) vllm_model = RadioModel(vllm_config)
vllm_model.load_weights(hf_model.state_dict()) vllm_model.load_weights(hf_model.state_dict())
vllm_model = vllm_model.to("cuda", torch_dtype) vllm_model = vllm_model.to(DEVICE_TYPE, torch_dtype)
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values vllm_model(pixel_values=pixel_value.to(DEVICE_TYPE))
for pixel_value in pixel_values
] ]
del vllm_model, hf_model del vllm_model, hf_model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
......
...@@ -10,6 +10,8 @@ from vllm.model_executor.models.utils import ( ...@@ -10,6 +10,8 @@ from vllm.model_executor.models.utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
class ModuleWithBatchNorm(torch.nn.Module): class ModuleWithBatchNorm(torch.nn.Module):
def __init__(self): def __init__(self):
...@@ -174,8 +176,12 @@ class raise_if_cuda_sync: ...@@ -174,8 +176,12 @@ class raise_if_cuda_sync:
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_merge_multimodal_embeddings_no_sync(): def test_merge_multimodal_embeddings_no_sync():
inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device="cuda:0") inputs_embeds = torch.zeros(
multimodal_embeddings = [torch.ones([3, 10], dtype=torch.bfloat16, device="cuda:0")] [5, 10], dtype=torch.bfloat16, device=f"{DEVICE_TYPE}:0"
)
multimodal_embeddings = [
torch.ones([3, 10], dtype=torch.bfloat16, device=f"{DEVICE_TYPE}:0")
]
is_multimodal = torch.tensor([True, False, True, True, False], device="cpu") is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")
with raise_if_cuda_sync(): with raise_if_cuda_sync():
_merge_multimodal_embeddings( _merge_multimodal_embeddings(
......
...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.fp8 import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
MODELS = [ MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
# The checkpoint below was removed from the HF. # The checkpoint below was removed from the HF.
...@@ -314,7 +316,7 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -314,7 +316,7 @@ def test_scaled_fp8_quant(dtype) -> None:
# Note that we use a shape % 4 != 0 to cover edge cases, # Note that we use a shape % 4 != 0 to cover edge cases,
# because scaled_fp8_quant is vectorized by 4. # because scaled_fp8_quant is vectorized by 4.
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype) x = (torch.randn(size=(11, 11), device=DEVICE_TYPE) * 13).to(dtype)
# Dynamic quantization # Dynamic quantization
ref_y, inv_scale = ops.scaled_fp8_quant(x, None) ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
...@@ -338,7 +340,9 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -338,7 +340,9 @@ def test_scaled_fp8_quant(dtype) -> None:
# non-contiguous input with padding # non-contiguous input with padding
m, n, padded_stride = 975, 512, 576 m, n, padded_stride = 975, 512, 576
padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype) padded_tensor = (torch.randn(size=(m, padded_stride), device=DEVICE_TYPE) * 13).to(
dtype
)
x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1) x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1)
assert not x_nc.is_contiguous() assert not x_nc.is_contiguous()
...@@ -409,7 +413,7 @@ def test_fp8_reloading( ...@@ -409,7 +413,7 @@ def test_fp8_reloading(
# Set model config as model_config.dtype is required in Fp8LinearMethod. # Set model config as model_config.dtype is required in Fp8LinearMethod.
default_vllm_config.model_config = ModelConfig() default_vllm_config.model_config = ModelConfig()
with torch.device("cuda:0"): with torch.device(f"{DEVICE_TYPE}:0"):
config = Fp8Config( config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
weight_block_size=weight_block_size, weight_block_size=weight_block_size,
......
...@@ -25,11 +25,13 @@ from vllm.platforms import current_platform ...@@ -25,11 +25,13 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache
DEVICE_TYPE = current_platform.device_type
# Skip entire module if no CUDA/ROCm GPU available # Skip entire module when running on a CPU-only platform (tests require a GPU)
pytestmark = [ pytestmark = [
pytest.mark.skipif( pytest.mark.skipif(
not current_platform.is_cuda_alike(), current_platform.is_cpu(),
reason="Per-token-head KV cache tests require CUDA or ROCm GPU.", reason="Per-token-head KV cache tests require GPU.",
), ),
] ]
...@@ -166,7 +168,7 @@ def test_reshape_and_cache_per_token_head( ...@@ -166,7 +168,7 @@ def test_reshape_and_cache_per_token_head(
) )
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
num_blocks = (num_tokens + block_size - 1) // block_size + 4 num_blocks = (num_tokens + block_size - 1) // block_size + 4
...@@ -260,7 +262,7 @@ def test_per_token_head_round_trip_accuracy( ...@@ -260,7 +262,7 @@ def test_per_token_head_round_trip_accuracy(
triton_reshape_and_cache_flash_per_token_head_quant, triton_reshape_and_cache_flash_per_token_head_quant,
) )
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
set_random_seed(42) set_random_seed(42)
num_blocks = (num_tokens + block_size - 1) // block_size + 2 num_blocks = (num_tokens + block_size - 1) // block_size + 2
...@@ -323,7 +325,7 @@ def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig): ...@@ -323,7 +325,7 @@ def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
triton_reshape_and_cache_flash_per_token_head_quant, triton_reshape_and_cache_flash_per_token_head_quant,
) )
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
num_tokens = 4 num_tokens = 4
num_heads = 2 num_heads = 2
head_size = 64 head_size = 64
...@@ -430,7 +432,7 @@ def test_triton_unified_attention_per_token_head_scale( ...@@ -430,7 +432,7 @@ def test_triton_unified_attention_per_token_head_scale(
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.ops.triton_unified_attention import unified_attention from vllm.v1.attention.ops.triton_unified_attention import unified_attention
torch.set_default_device("cuda") torch.set_default_device(DEVICE_TYPE)
set_random_seed(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
......
...@@ -36,6 +36,8 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( ...@@ -36,6 +36,8 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark") importlib.metadata.version("amd-quark")
) >= version.parse(QUARK_MXFP4_MIN_VERSION) ) >= version.parse(QUARK_MXFP4_MIN_VERSION)
DEVICE_TYPE = current_platform.device_type
if QUARK_MXFP4_AVAILABLE: if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
from quark.torch.kernel import mx as mx_kernel from quark.torch.kernel import mx as mx_kernel
...@@ -309,7 +311,7 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in ...@@ -309,7 +311,7 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in
torch.manual_seed(0) torch.manual_seed(0)
hidden_size = 64 * 32 hidden_size = 64 * 32
inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2 inp = (torch.rand(1, hidden_size, dtype=float_dtype, device=DEVICE_TYPE) - 0.5) * 2
for i in range(hidden_size // 32): for i in range(hidden_size // 32):
inp[:, i * 32 : (i + 1) * 32] = ( inp[:, i * 32 : (i + 1) * 32] = (
inp[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)] inp[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)]
...@@ -353,15 +355,15 @@ def test_mxfp4_dequant_kernel_match_quark( ...@@ -353,15 +355,15 @@ def test_mxfp4_dequant_kernel_match_quark(
reorder=False, reorder=False,
real_quantized=True, real_quantized=True,
float_dtype=float_dtype, float_dtype=float_dtype,
device="cuda", device=DEVICE_TYPE,
) )
observer = qspec.observer_cls(qspec, device="cuda") observer = qspec.observer_cls(qspec, device=DEVICE_TYPE)
hidden_size = 512 hidden_size = 512
shape = (11008, hidden_size) shape = (11008, hidden_size)
w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2 w = (torch.rand(shape, device=DEVICE_TYPE, dtype=float_dtype) - 0.5) * 2
# Make it so that different groups have different scales. # Make it so that different groups have different scales.
for i in range(hidden_size // 32): for i in range(hidden_size // 32):
...@@ -373,7 +375,7 @@ def test_mxfp4_dequant_kernel_match_quark( ...@@ -373,7 +375,7 @@ def test_mxfp4_dequant_kernel_match_quark(
scale, _ = observer._calculate_qparams() scale, _ = observer._calculate_qparams()
weight_quantizer.scale = scale weight_quantizer.scale = scale
w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda") w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to(DEVICE_TYPE)
weight_quantizer.maybe_convert_and_transpose_scale() weight_quantizer.maybe_convert_and_transpose_scale()
scale = weight_quantizer.scale scale = weight_quantizer.scale
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment