Unverified Commit 6ffa3f31 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[CI/Build] Avoid CUDA initialization (#8534)

parent e3515729
...@@ -5,6 +5,7 @@ from einops import rearrange, repeat ...@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update) selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
def selective_state_update_ref(state, def selective_state_update_ref(state,
...@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, ...@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw = max(rtolw, rtol) rtolw = max(rtolw, rtol)
atolw = max(atolw, atol) atolw = max(atolw, atol)
# set seed # set seed
torch.random.manual_seed(0) seed_everything(0)
batch_size = 2 batch_size = 2
dim = 4 dim = 4
dstate = 8 dstate = 8
...@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): ...@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip: if torch.version.hip:
atol *= 2 atol *= 2
# set seed # set seed
torch.random.manual_seed(0) seed_everything(0)
batch_size = 1 batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype) x = torch.randn(batch_size, dim, device=device, dtype=itype)
......
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize) marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything
def torch_moe(a, w1, w2, score, topk): def torch_moe(a, w1, w2, score, topk):
...@@ -151,7 +152,7 @@ def test_fused_marlin_moe( ...@@ -151,7 +152,7 @@ def test_fused_marlin_moe(
act_order: bool, act_order: bool,
num_bits: int, num_bits: int,
): ):
torch.manual_seed(7) seed_everything(7)
if topk > e: if topk > e:
return return
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
import torch import torch
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
...@@ -46,9 +47,8 @@ def test_rotary_embedding( ...@@ -46,9 +47,8 @@ def test_rotary_embedding(
) -> None: ) -> None:
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
torch.random.manual_seed(seed)
if torch.cuda.is_available(): seed_everything(seed)
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
...@@ -100,9 +100,7 @@ def test_batched_rotary_embedding( ...@@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
...@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
......
...@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask ...@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
...@@ -39,10 +39,7 @@ def test_contexted_kv_attention( ...@@ -39,10 +39,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
random.seed(0) seed_everything(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
...@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi( ...@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
random.seed(0) seed_everything(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
......
...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything
from .utils import DummyLoRAManager from .utils import DummyLoRAManager
...@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None: seq_len) -> None:
dtype = torch.float16 dtype = torch.float16
seed = 0 seed = 0
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
......
...@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests ...@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64]. is set to [1, 2, 4, 8, 16, 32, 64].
""" """
import random
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand ...@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -145,11 +145,8 @@ def test_punica_sgmv( ...@@ -145,11 +145,8 @@ def test_punica_sgmv(
seed: int, seed: int,
device: str, device: str,
): ):
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 seq_length = 128
( (
...@@ -238,11 +235,8 @@ def test_punica_bgmv( ...@@ -238,11 +235,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 1 seq_length = 1
( (
...@@ -329,11 +323,9 @@ def test_punica_expand_nslices( ...@@ -329,11 +323,9 @@ def test_punica_expand_nslices(
): ):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
inputs_tensor, inputs_tensor,
......
...@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally ...@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and under different conditions, including various batches, numbers of LoRA , and
maximum ranks. maximum ranks.
""" """
import random
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand ...@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -60,11 +60,8 @@ def test_punica_sgmv( ...@@ -60,11 +60,8 @@ def test_punica_sgmv(
seed: int, seed: int,
device: str, device: str,
): ):
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 seq_length = 128
( (
...@@ -153,11 +150,8 @@ def test_punica_bgmv( ...@@ -153,11 +150,8 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 1 seq_length = 1
( (
...@@ -244,11 +238,9 @@ def test_punica_expand_nslices( ...@@ -244,11 +238,9 @@ def test_punica_expand_nslices(
): ):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
inputs_tensor, inputs_tensor,
......
...@@ -2,23 +2,18 @@ ...@@ -2,23 +2,18 @@
Run `pytest tests/models/test_granite.py`. Run `pytest tests/models/test_granite.py`.
""" """
import importlib.metadata
import pytest import pytest
import transformers
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
TRANSFORMERS_VERSION = tuple(
map(int,
importlib.metadata.version("transformers").split(".")))
MODELS = [ MODELS = [
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
] ]
# GraniteForCausalLM will be in transformers >= 4.45 # GraniteForCausalLM will be in transformers >= 4.45
@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), @pytest.mark.skipif(transformers.__version__ < "4.45",
reason="granite model test requires transformers >= 4.45") reason="granite model test requires transformers >= 4.45")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
......
...@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, ...@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert attn._k_scale == 1.0 assert attn._k_scale == 1.0
assert attn._v_scale == 1.0 assert attn._v_scale == 1.0
capability = current_platform.get_device_capability() if current_platform.has_device_capability(89) and not force_marlin:
capability = capability[0] * 10 + capability[1]
if capability >= 89 and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8 # For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn assert fc1.weight.dtype == torch.float8_e4m3fn
else: else:
......
...@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool: ...@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
return False return False
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] assert capability is not None
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability()) min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
return capability.to_int() >= min_capability
...@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState, ...@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
if torch.cuda.get_device_capability()[0] != 9: if not current_platform.has_device_capability(90):
self.use_naive_attn = True self.use_naive_attn = True
else: else:
try: try:
......
...@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip ...@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
and current_platform.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE: if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
...@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module): ...@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
use_spda = is_hip() or is_cpu() or not \ use_spda = is_hip() or is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device() device = device or (torch.cuda.current_device()
if torch.cuda.is_available() else "cpu") if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device) device = torch.device(device)
# NOTE: vllm CPU backend support BF16 instead of FP16. # NOTE: vllm CPU backend support BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
......
...@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0": ...@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None, alibi_slopes=None,
sliding_window=None): sliding_window=None):
cap = current_platform.get_device_capability() BLOCK = 128 if current_platform.has_device_capability(80) else 64
BLOCK = 128 if cap[0] >= 8 else 64
NUM_WARPS = 8 NUM_WARPS = 8
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
......
...@@ -203,7 +203,7 @@ def which_attn_to_use( ...@@ -203,7 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
if current_platform.get_device_capability()[0] != 9: if not current_platform.has_device_capability(90):
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
...@@ -212,7 +212,7 @@ def which_attn_to_use( ...@@ -212,7 +212,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN: if selected_backend == _Backend.FLASH_ATTN:
if current_platform.get_device_capability()[0] < 8: if not current_platform.has_device_capability(80):
# Volta and Turing NVIDIA GPUs. # Volta and Turing NVIDIA GPUs.
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing " "Cannot use FlashAttention-2 backend for Volta and Turing "
......
...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, ...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_cpu, is_hip, is_neuron, is_openvino, is_xpu, is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once) print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -1035,20 +1035,20 @@ class DeviceConfig: ...@@ -1035,20 +1035,20 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":
# Automated device type detection # Automated device type detection
if is_neuron(): if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_openvino(): elif is_openvino():
self.device_type = "openvino" self.device_type = "openvino"
elif current_platform.is_tpu(): elif current_platform.is_tpu():
self.device_type = "tpu" self.device_type = "tpu"
elif is_cpu(): elif current_platform.is_cpu():
self.device_type = "cpu" self.device_type = "cpu"
elif is_xpu(): elif is_xpu():
self.device_type = "xpu" self.device_type = "xpu"
else: else:
# We don't call torch.cuda.is_available() here to raise RuntimeError("Failed to infer device type")
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
else: else:
# Device type is assigned explicitly # Device type is assigned explicitly
self.device_type = device self.device_type = device
......
...@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup ...@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
@dataclass @dataclass
...@@ -191,7 +192,7 @@ class GroupCoordinator: ...@@ -191,7 +192,7 @@ class GroupCoordinator:
assert self.cpu_group is not None assert self.cpu_group is not None
assert self.device_group is not None assert self.device_group is not None
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}") self.device = torch.device(f"cuda:{local_rank}")
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
......
...@@ -60,6 +60,7 @@ if TYPE_CHECKING: ...@@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
......
...@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self, def _check_scheme_supported(self,
min_capability: int, min_capability: int,
error: bool = True) -> bool: error: bool = True) -> bool:
capability = current_platform.get_device_capability() # type: ignore capability_tuple = current_platform.get_device_capability()
if capability is not None: if capability_tuple is not None:
capability = capability[0] * 10 + capability[1] capability = capability_tuple.to_int()
supported = capability >= min_capability supported = capability >= min_capability
if error and not supported: if error and not supported:
raise RuntimeError( raise RuntimeError(
......
...@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability() self.use_marlin = not current_platform.has_device_capability(89)
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
......
...@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability() self.use_marlin = (not current_platform.has_device_capability(89)
capability = capability[0] * 10 + capability[1] or envs.VLLM_TEST_FORCE_FP8_MARLIN)
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm # Disable marlin for rocm
if is_hip(): if is_hip():
self.use_marlin = False self.use_marlin = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment