Unverified Commit 4e2d95e3 authored by wangshuai09's avatar wangshuai09 Committed by GitHub
Browse files

[Hardware][ROCM] using current_platform.is_rocm (#9642)


Signed-off-by: default avatarwangshuai09 <391746016@qq.com>
parent 34a99416
...@@ -11,7 +11,7 @@ from unittest.mock import patch ...@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.utils import is_hip from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
...@@ -51,7 +51,7 @@ def test_models( ...@@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool, enforce_eager: bool,
) -> None: ) -> None:
if backend == "FLASHINFER" and is_hip(): if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend os.environ["VLLM_ATTENTION_BACKEND"] = backend
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip from vllm.platforms import current_platform
TEST_MODELS = [ TEST_MODELS = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
...@@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"): ...@@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization": "marlin" "quantization": "marlin"
})) }))
if not is_hip() and is_quant_method_supported("awq"): if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ" "quantization": "AWQ"
})) }))
......
...@@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union ...@@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import torch import torch
from vllm.utils import is_hip from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy # Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm. # issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0 ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
...@@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor, ...@@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype) else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max) qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0) s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0) s_512 = as_float32_tensor(512.0)
...@@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]: -> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE) fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max) fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0) one = as_float32_tensor(1.0)
......
...@@ -6,11 +6,12 @@ import torch ...@@ -6,11 +6,12 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
if not is_hip(): if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
...@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 ...@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float DTYPES = [
] if not is_hip() else [torch.half, torch.bfloat16] torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
...@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention( ...@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) "version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
...@@ -317,8 +320,8 @@ def test_paged_attention( ...@@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two # NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two # implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test. # outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
...@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention( ...@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(), @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() @torch.inference_mode()
def test_multi_query_kv_attention( def test_multi_query_kv_attention(
...@@ -425,6 +428,6 @@ def test_multi_query_kv_attention( ...@@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale, scale,
dtype, dtype,
) )
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
...@@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False) False)
assert backend.name == "TORCH_SDPA" assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True): with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == "ROCM_FLASH" assert backend.name == "ROCM_FLASH"
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import ( from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn) LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
...@@ -316,8 +317,8 @@ def test_paged_attention( ...@@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two # NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two # implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test. # outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
......
...@@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, ...@@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.utils import is_hip from vllm.platforms import current_platform
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
...@@ -82,7 +82,7 @@ class TestResources(NamedTuple): ...@@ -82,7 +82,7 @@ class TestResources(NamedTuple):
will leverage attn_backend for the purpose of will leverage attn_backend for the purpose of
constructing backend-compatible attention constructing backend-compatible attention
metadata instances metadata instances
Attributes: Attributes:
* scale: 1/sqrt(d) scale factor for attn * scale: 1/sqrt(d) scale factor for attn
...@@ -105,10 +105,10 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources: ...@@ -105,10 +105,10 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
Build key components for performing encoder/decoder attention test. Build key components for performing encoder/decoder attention test.
Note that Note that
(1) The Attention instance constructed here, automatically selects (1) The Attention instance constructed here, automatically selects
an attention backend class based on platform info & a set of canned an attention backend class based on platform info & a set of canned
heuristics, so heuristics, so
(2) The attention backend instance constructed here is thus *not (2) The attention backend instance constructed here is thus *not
the same backend instance* used by attn, but rather it is the same backend instance* used by attn, but rather it is
intended to be a *different instance* of the *same backend class*; intended to be a *different instance* of the *same backend class*;
therefore, therefore,
...@@ -156,7 +156,7 @@ def _encoder_attn_setup( ...@@ -156,7 +156,7 @@ def _encoder_attn_setup(
''' '''
Set up test vectors & data structures for encoder attention test. Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed. A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries. sequences will have the same length as the corresponding queries.
...@@ -169,14 +169,14 @@ def _encoder_attn_setup( ...@@ -169,14 +169,14 @@ def _encoder_attn_setup(
Arguments: Arguments:
* test_pt: TestPoint data structure; this function relies on the * test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size, following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the * test_rsrcs: TestResources data structure; this function relies on the
scale field scale field
Returns: Returns:
* PhaseTestParameters data structure comprising (1) packed query/key/value * PhaseTestParameters data structure comprising (1) packed query/key/value
tensors, (2) the ideal output of attention computed using a naive tensors, (2) the ideal output of attention computed using a naive
implementation, and (3) KVCache field set to None implementation, and (3) KVCache field set to None
...@@ -265,7 +265,7 @@ def _decoder_attn_setup( ...@@ -265,7 +265,7 @@ def _decoder_attn_setup(
Arguments: Arguments:
* test_pt: TestPoint data structure; this function relies on the * test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size, following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the * test_rsrcs: TestResources data structure; this function relies on the
scale field scale field
...@@ -275,14 +275,14 @@ def _decoder_attn_setup( ...@@ -275,14 +275,14 @@ def _decoder_attn_setup(
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x * qkv: Unpacked (batch_size x padded_seq_len x num_heads x
head_size) query/key/value tensors head_size) query/key/value tensors
* Prefill-phase decoder self-attention PhaseTestParameters data structure, * Prefill-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size) including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data computed using a naive implementation, and (3) memory-mapping data
structures appropriate for prefill phase. structures appropriate for prefill phase.
* Decode-phase decoder self-attention PhaseTestParameters data structure, * Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size) including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data computed using a naive implementation, and (3) memory-mapping data
structures appropriate for decode phase. structures appropriate for decode phase.
* max_block_idx: max physical address in decoder self-attention block-table * max_block_idx: max physical address in decoder self-attention block-table
(intended to be used as the base address for the encoder/ (intended to be used as the base address for the encoder/
...@@ -436,12 +436,12 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -436,12 +436,12 @@ def _enc_dec_cross_attn_setup_reuses_query(
This function also constructs the cross-attention KV cache memory mapping This function also constructs the cross-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at (slot mapping and block table), ensuring that the block table starts at
block_base_addr. block_base_addr.
Arguments: Arguments:
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
num_heads x head_size) decoder self-attention inputs; num_heads x head_size) decoder self-attention inputs;
this function relies on the query and q_seq_lens this function relies on the query and q_seq_lens
fields fields
* encoder_test_params: PhaseTestParameters data structure which was * encoder_test_params: PhaseTestParameters data structure which was
...@@ -452,7 +452,7 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -452,7 +452,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
self-attention; all fields self-attention; all fields
including KV cache required including KV cache required
* test_pt: TestPoint data structure; this function relies on the * test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size, following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the * test_rsrcs: TestResources data structure; this function relies on the
scale field scale field
...@@ -460,16 +460,16 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -460,16 +460,16 @@ def _enc_dec_cross_attn_setup_reuses_query(
Returns: Returns:
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors (number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase. for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data * Decode-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors (number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate naive implementation, and (3) memory-mapping data structures appropriate
for decode phase. for decode phase.
''' '''
...@@ -596,7 +596,7 @@ def _run_encoder_attention_test( ...@@ -596,7 +596,7 @@ def _run_encoder_attention_test(
''' '''
Run encoder attention. Run encoder attention.
attn.forward() is passed attn_type=AttentionType.ENCODER in order attn.forward() is passed attn_type=AttentionType.ENCODER in order
to configure the kernel invocation for encoder attention to configure the kernel invocation for encoder attention
Requires attn_metadata.num_decode_tokens == 0 Requires attn_metadata.num_decode_tokens == 0
...@@ -607,7 +607,7 @@ def _run_encoder_attention_test( ...@@ -607,7 +607,7 @@ def _run_encoder_attention_test(
* attn: Attention wrapper instance * attn: Attention wrapper instance
* encoder_test_params: encoder PhaseTestParameters data structure; * encoder_test_params: encoder PhaseTestParameters data structure;
this function relies on the packed this function relies on the packed
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
query/key/value fields query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention * attn_metadata: attention metadata for encoder/decoder-self attention
...@@ -646,7 +646,7 @@ def _run_decoder_self_attention_test( ...@@ -646,7 +646,7 @@ def _run_decoder_self_attention_test(
and attn (Attention wrapper instance) fields and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure; * decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed this function relies on the packed
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
query/key/value fields query/key/value fields
* attn_metadata: attention metadata for decoder-self attention * attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping) (contains KV cache memory-mapping)
...@@ -694,11 +694,11 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -694,11 +694,11 @@ def _run_encoder_decoder_cross_attention_test(
and attn (Attention wrapper instance) fields and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure; * decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed this function relies on the packed
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
query field query field
* cross_test_params: encoder/decoder PhaseTestParameters data structure; * cross_test_params: encoder/decoder PhaseTestParameters data structure;
this function relies on the packed this function relies on the packed
(number_of_tokens x num_heads x head_size) (number_of_tokens x num_heads x head_size)
key/value fields key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention * attn_metadata: attention metadata for encoder/decoder-self attention
...@@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type=attn_type) attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
...@@ -755,7 +756,8 @@ def test_encoder_only( ...@@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention. No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip(). AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test auto-selection process, forcing the specific backend-under-test
...@@ -811,7 +813,8 @@ def test_encoder_only( ...@@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
...@@ -837,14 +840,14 @@ def test_e2e_enc_dec_attn( ...@@ -837,14 +840,14 @@ def test_e2e_enc_dec_attn(
attributes for prefill-phase, and (2) an analogous attention metadata attributes for prefill-phase, and (2) an analogous attention metadata
structure but for decode-phase structure but for decode-phase
* Test attention steps in the following order * Test attention steps in the following order
* Encoder attention * Encoder attention
* Prefill self-attention * Prefill self-attention
* Prefill cross-attention * Prefill cross-attention
* Decode self-attention * Decode self-attention
* Decode cross-attention * Decode cross-attention
* Besides being reflective of realistic use-cases, this order would * Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention exacerbate any accidental overlap in the self-/cross-attention
block tables, which one hopes to avoid block tables, which one hopes to avoid
...@@ -864,10 +867,11 @@ def test_e2e_enc_dec_attn( ...@@ -864,10 +867,11 @@ def test_e2e_enc_dec_attn(
to be utilized. to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip(). AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross), all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior (decoder & enc/dec cross.) This is intended to reflect the behavior
of EncoderDecoderModelRunner, which constructs a single attention metadata of EncoderDecoderModelRunner, which constructs a single attention metadata
......
...@@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize) marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import is_hip, seed_everything from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
...@@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, m: int,
n: int, n: int,
...@@ -256,7 +257,7 @@ def test_fused_marlin_moe( ...@@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply( def test_single_marlin_moe_multiply(
m: int, m: int,
n: int, n: int,
......
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import is_hip from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b" MODEL_PATH = "google/gemma-7b"
...@@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: ...@@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts return generated_texts
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") @pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files): def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH, llm = vllm.LLM(MODEL_PATH,
max_model_len=1024, max_model_len=1024,
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import vllm import vllm
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import is_hip from vllm.platforms import current_platform
@dataclass @dataclass
...@@ -19,7 +19,7 @@ class ModelWithQuantization: ...@@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS: List[ModelWithQuantization] MODELS: List[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm. #AWQ quantization is currently not supported in ROCm.
if is_hip(): if current_platform.is_rocm():
MODELS = [ MODELS = [
ModelWithQuantization( ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
......
...@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, ...@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding) BatchEncoding)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
...@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"] ...@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to, # ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime. # excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan) # FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip(): if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
...@@ -70,7 +71,7 @@ def run_test( ...@@ -70,7 +71,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS. All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input. For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input. and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract. Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf. The text output is sanitized to be able to compare with hf.
...@@ -151,7 +152,7 @@ def run_test( ...@@ -151,7 +152,7 @@ def run_test(
pytest.param( pytest.param(
"float", "float",
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
is_hip(), current_platform.is_rocm(),
reason= reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma") "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half" ), "half"
......
...@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry ...@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets) _ImageAssets)
...@@ -56,7 +55,7 @@ if current_platform.is_cpu(): ...@@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models, # ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime # use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan) # FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip(): if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
......
...@@ -5,7 +5,7 @@ tensor parallelism. ...@@ -5,7 +5,7 @@ tensor parallelism.
import pytest import pytest
import torch import torch
from vllm.utils import is_hip from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp from .conftest import run_equality_correctness_test_tp
...@@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, output_len: int, seed: int): batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used. """Verify greedy equality when tensor parallelism is used.
""" """
if is_hip(): if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet") pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m", run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs, common_llm_kwargs,
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader ...@@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes, from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip) cuda_device_count_stateless, get_open_port)
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage, from amdsmi import (amdsmi_get_gpu_vram_usage,
...@@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int], ...@@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output: Dict[int, str] = {} output: Dict[int, str] = {}
output_raw: Dict[int, float] = {} output_raw: Dict[int, float] = {}
for device in devices: for device in devices:
if is_hip(): if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device] dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle) mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10 gb_used = mem_info["vram_used"] / 2**10
......
...@@ -659,11 +659,11 @@ def scaled_fp8_quant( ...@@ -659,11 +659,11 @@ def scaled_fp8_quant(
Args: Args:
input: The input tensor to be quantized to FP8 input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic scale_ub: Optional upper bound for scaling factor in dynamic
per token case per token case
num_token_padding: If specified, pad the first dimension num_token_padding: If specified, pad the first dimension
of the output to at least this value. of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case. in the dynamic quantization case.
Returns: Returns:
...@@ -674,8 +674,8 @@ def scaled_fp8_quant( ...@@ -674,8 +674,8 @@ def scaled_fp8_quant(
assert (input.ndim == 2) assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \ out_dtype: torch.dtype = torch.float8_e4m3fnuz \
else torch.float8_e4m3fn if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding: if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1]) shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype) output = torch.empty(shape, device=input.device, dtype=out_dtype)
......
...@@ -3,7 +3,6 @@ import math ...@@ -3,7 +3,6 @@ import math
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
...@@ -32,8 +31,9 @@ class LocalStridedBlockSparseAttn(torch.nn.Module): ...@@ -32,8 +31,9 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
): ):
super().__init__() super().__init__()
if use_spda is None: if use_spda is None:
use_spda = is_hip() or current_platform.is_cpu() or not \ use_spda = current_platform.is_rocm() or \
IS_COMPUTE_8_OR_ABOVE current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device() device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu") if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device) device = torch.device(device)
......
...@@ -10,7 +10,7 @@ import vllm.envs as envs ...@@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -208,7 +208,7 @@ def which_attn_to_use( ...@@ -208,7 +208,7 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS return _Backend.PALLAS
if is_hip(): if current_platform.is_rocm():
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
......
...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, ...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, print_warning_once) print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -43,7 +43,7 @@ class ModelConfig: ...@@ -43,7 +43,7 @@ class ModelConfig:
Args: Args:
model: Name or path of the huggingface model to use. model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified. output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks. one task, even if the same model can be used for multiple tasks.
...@@ -99,15 +99,15 @@ class ModelConfig: ...@@ -99,15 +99,15 @@ class ModelConfig:
skip_tokenizer_init: If true, skip initialization of tokenizer and skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer. detokenizer.
served_model_name: The model name used in metrics tag `model_name`, served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified, names provided, the first name will be used. If not specified,
the model name will be the same as `model`. the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models. per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices, override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments. can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded. config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'. Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor mm_processor_kwargs: Arguments to be forwarded to the model's processor
...@@ -350,7 +350,7 @@ class ModelConfig: ...@@ -350,7 +350,7 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"Unknown quantization method: {self.quantization}. Must " f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.") f"be one of {supported_quantization}.")
if is_hip( if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization: ) and self.quantization not in rocm_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
...@@ -365,7 +365,7 @@ class ModelConfig: ...@@ -365,7 +365,7 @@ class ModelConfig:
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization) "non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip() if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ): and not envs.VLLM_USE_TRITON_AWQ):
logger.warning( logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
...@@ -385,7 +385,7 @@ class ModelConfig: ...@@ -385,7 +385,7 @@ class ModelConfig:
def _verify_bnb_config(self) -> None: def _verify_bnb_config(self) -> None:
""" """
The current version of bitsandbytes (0.44.0) with 8-bit models does not The current version of bitsandbytes (0.44.0) with 8-bit models does not
yet support CUDA graph. yet support CUDA graph.
""" """
is_bitsandbytes = self.quantization == "bitsandbytes" is_bitsandbytes = self.quantization == "bitsandbytes"
...@@ -810,7 +810,7 @@ class LoadConfig: ...@@ -810,7 +810,7 @@ class LoadConfig:
fast weight loading. fast weight loading.
"bitsandbytes" will load nf4 type weights. "bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model. ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's Default to "original/**/*" to avoid repeated loading of llama's
checkpoints. checkpoints.
""" """
...@@ -843,7 +843,8 @@ class LoadConfig: ...@@ -843,7 +843,8 @@ class LoadConfig:
self.load_format = LoadFormat(load_format) self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = [] rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format: if current_platform.is_rocm(
) and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [ rocm_supported_load_format = [
f for f in LoadFormat.__members__ f for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format) if (f not in rocm_not_supported_load_format)
...@@ -967,7 +968,7 @@ class ParallelConfig: ...@@ -967,7 +968,7 @@ class ParallelConfig:
if self.use_ray: if self.use_ray:
from vllm.executor import ray_utils from vllm.executor import ray_utils
ray_utils.assert_ray_available() ray_utils.assert_ray_available()
if is_hip(): if current_platform.is_rocm():
self.disable_custom_all_reduce = True self.disable_custom_all_reduce = True
logger.info( logger.info(
"Disabled the custom all-reduce kernel because it is not " "Disabled the custom all-reduce kernel because it is not "
...@@ -996,7 +997,7 @@ class SchedulerConfig: ...@@ -996,7 +997,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt. prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens. on the remaining max_num_batched_tokens.
preemption_mode: Whether to perform preemption by swapping or preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows: recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences swapping. However, when the sequence group has multiple sequences
...@@ -1215,7 +1216,7 @@ class SpeculativeConfig: ...@@ -1215,7 +1216,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold (Optional[float]): typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be probability of a token in the target model for it to be
accepted. This threshold is used only when we use the accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance. TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]): typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the A scaling factor for the entropy-based threshold in the
...@@ -1225,7 +1226,7 @@ class SpeculativeConfig: ...@@ -1225,7 +1226,7 @@ class SpeculativeConfig:
If set to False, token log probabilities are returned If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams. according to the log probability settings in SamplingParams.
If not specified, it defaults to True. If not specified, it defaults to True.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None. the necessary conditions are met, else None.
...@@ -1470,13 +1471,13 @@ class SpeculativeConfig: ...@@ -1470,13 +1471,13 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold (Optional[float]): typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be probability of a token in the target model for it to be
accepted. This threshold is used only when we use the accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance. TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]): typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler. TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be determined. If set to False, log probabilities will be
...@@ -1843,10 +1844,10 @@ def get_min_sliding_window( ...@@ -1843,10 +1844,10 @@ def get_min_sliding_window(
def get_served_model_name(model: str, def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]): served_model_name: Optional[Union[str, List[str]]]):
""" """
If the input is a non-empty list, the first model_name in If the input is a non-empty list, the first model_name in
`served_model_name` is taken. `served_model_name` is taken.
If the input is a non-empty string, it is used directly. If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an For cases where the input is either an empty string or an
empty list, the fallback is to use `self.model`. empty list, the fallback is to use `self.model`.
""" """
if not served_model_name: if not served_model_name:
......
...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook ...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -231,7 +231,7 @@ def initialize_ray_cluster( ...@@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available() assert_ray_available()
# Connect to a ray cluster. # Connect to a ray cluster.
if is_hip() or current_platform.is_xpu(): if current_platform.is_rocm() or current_platform.is_xpu():
ray.init(address=ray_address, ray.init(address=ray_address,
ignore_reinit_error=True, ignore_reinit_error=True,
num_gpus=parallel_config.world_size) num_gpus=parallel_config.world_size)
......
...@@ -7,7 +7,7 @@ import vllm.envs as envs ...@@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -72,7 +72,7 @@ class CustomOp(nn.Module): ...@@ -72,7 +72,7 @@ class CustomOp(nn.Module):
if not enabled: if not enabled:
return self.forward_native return self.forward_native
if is_hip(): if current_platform.is_rocm():
return self.forward_hip return self.forward_hip
elif current_platform.is_cpu(): elif current_platform.is_cpu():
return self.forward_cpu return self.forward_cpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment