Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4e2d95e3
Unverified
Commit
4e2d95e3
authored
Oct 28, 2024
by
wangshuai09
Committed by
GitHub
Oct 28, 2024
Browse files
[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by:
wangshuai09
<
391746016@qq.com
>
parent
34a99416
Changes
32
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
131 additions
and
114 deletions
+131
-114
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+2
-2
tests/compile/utils.py
tests/compile/utils.py
+2
-2
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+11
-6
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+13
-10
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+2
-1
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+4
-3
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+40
-36
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+4
-3
tests/lora/test_gemma.py
tests/lora/test_gemma.py
+3
-2
tests/lora/test_quant_model.py
tests/lora/test_quant_model.py
+2
-2
tests/models/decoder_only/vision_language/test_paligemma.py
tests/models/decoder_only/vision_language/test_paligemma.py
+5
-4
tests/models/decoder_only/vision_language/test_phi3v.py
tests/models/decoder_only/vision_language/test_phi3v.py
+1
-2
tests/spec_decode/e2e/test_integration_dist_tp2.py
tests/spec_decode/e2e/test_integration_dist_tp2.py
+2
-2
tests/utils.py
tests/utils.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-4
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+3
-3
vllm/attention/selector.py
vllm/attention/selector.py
+2
-2
vllm/config.py
vllm/config.py
+25
-24
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+2
-2
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+2
-2
No files found.
tests/basic_correctness/test_basic_correctness.py
View file @
4e2d95e3
...
...
@@ -11,7 +11,7 @@ from unittest.mock import patch
import
pytest
from
vllm
import
LLM
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
..models.utils
import
check_outputs_equal
...
...
@@ -51,7 +51,7 @@ def test_models(
enforce_eager
:
bool
,
)
->
None
:
if
backend
==
"FLASHINFER"
and
is_hip
():
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
...
...
tests/compile/utils.py
View file @
4e2d95e3
...
...
@@ -5,7 +5,7 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
TEST_MODELS
=
[
(
"facebook/opt-125m"
,
{}),
...
...
@@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization"
:
"marlin"
}))
if
not
is_hip
()
and
is_quant_method_supported
(
"awq"
):
if
not
current_platform
.
is_rocm
()
and
is_quant_method_supported
(
"awq"
):
TEST_MODELS
.
append
((
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"
,
{
"quantization"
:
"AWQ"
}))
...
...
tests/kernels/quant_utils.py
View file @
4e2d95e3
...
...
@@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import
torch
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX
=
224.0
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
is_hip
()
else
torch
.
float8_e4m3fn
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
\
else
torch
.
float8_e4m3fn
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
...
...
@@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits
=
torch
.
iinfo
(
quant_dtype
)
if
quant_dtype
==
torch
.
int8
\
else
torch
.
finfo
(
quant_dtype
)
qtype_traits_max
=
ROCM_FP8_MAX
if
is_hip
()
else
qtype_traits
.
max
qtype_traits_min
=
-
ROCM_FP8_MAX
if
is_hip
()
else
qtype_traits
.
min
qtype_traits_max
=
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
else
qtype_traits
.
max
qtype_traits_min
=
-
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
else
qtype_traits
.
min
qtype_max
=
as_float32_tensor
(
qtype_traits_max
)
s_1
=
as_float32_tensor
(
1.0
)
s_512
=
as_float32_tensor
(
512.0
)
...
...
@@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
->
Tuple
[
torch
.
tensor
,
torch
.
tensor
]:
fp8_traits
=
torch
.
finfo
(
FP8_DTYPE
)
fp8_traits_max
=
ROCM_FP8_MAX
if
is_hip
()
else
fp8_traits
.
max
fp8_traits_min
=
-
ROCM_FP8_MAX
if
is_hip
()
else
fp8_traits
.
min
fp8_traits_max
=
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
else
fp8_traits
.
max
fp8_traits_min
=
-
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
else
fp8_traits
.
min
fp8_max
=
as_float32_tensor
(
fp8_traits_max
)
one
=
as_float32_tensor
(
1.0
)
...
...
tests/kernels/test_attention.py
View file @
4e2d95e3
...
...
@@ -6,11 +6,12 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_max_shared_memory_bytes
,
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
if
not
is_hip
():
if
not
current_platform
.
is_rocm
():
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
...
...
@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS
=
4321
# Arbitrary values for testing
PARTITION_SIZE
=
512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
if
not
is_hip
()
else
[
torch
.
half
,
torch
.
bfloat16
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
if
not
current_platform
.
is_rocm
()
else
[
torch
.
half
,
torch
.
bfloat16
]
NUM_GEN_SEQS
=
[
7
]
# Arbitrary values for testing
NUM_PREFILL_SEQS
=
[
3
]
# Arbitrary values for testing
NUM_HEADS
=
[(
40
,
40
),
(
64
,
8
)]
# Arbitrary values for testing
...
...
@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@
pytest
.
mark
.
parametrize
(
"version"
,
[
"v1"
,
"v2"
]
if
not
is_hip
()
else
[
"v1"
,
"v2"
,
"rocm"
])
"version"
,
[
"v1"
,
"v2"
]
if
not
current_platform
.
is_rocm
()
else
[
"v1"
,
"v2"
,
"rocm"
])
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_GEN_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
...
@@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
...
...
@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Xformers backend is not supported on ROCm."
)
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention
(
...
...
@@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale
,
dtype
,
)
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
tests/kernels/test_attention_selector.py
View file @
4e2d95e3
...
...
@@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False
)
assert
backend
.
name
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform.is_rocm"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"ROCM_FLASH"
...
...
tests/kernels/test_blocksparse_attention.py
View file @
4e2d95e3
...
...
@@ -7,7 +7,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
)
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_max_shared_memory_bytes
,
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
...
@@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
4e2d95e3
...
...
@@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
global_force_attn_backend_context_manager
)
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
]
...
...
@@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type
=
attn_type
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
...
...
@@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
...
...
@@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
...
...
@@ -864,7 +867,8 @@ def test_e2e_enc_dec_attn(
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
...
...
tests/kernels/test_moe.py
View file @
4e2d95e3
...
...
@@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
is_hip
,
seed_everything
from
vllm.utils
import
seed_everything
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
...
...
@@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Skip for rocm"
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_fused_marlin_moe
(
m
:
int
,
n
:
int
,
...
...
@@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Skip for rocm"
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_single_marlin_moe_multiply
(
m
:
int
,
n
:
int
,
...
...
tests/lora/test_gemma.py
View file @
4e2d95e3
...
...
@@ -4,7 +4,7 @@ import pytest
import
vllm
from
vllm.lora.request
import
LoRARequest
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
MODEL_PATH
=
"google/gemma-7b"
...
...
@@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return
generated_texts
@
pytest
.
mark
.
xfail
(
is_hip
(),
reason
=
"There can be output mismatch on ROCm"
)
@
pytest
.
mark
.
xfail
(
current_platform
.
is_rocm
(),
reason
=
"There can be output mismatch on ROCm"
)
def
test_gemma_lora
(
gemma_lora_files
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
max_model_len
=
1024
,
...
...
tests/lora/test_quant_model.py
View file @
4e2d95e3
...
...
@@ -8,7 +8,7 @@ import pytest
import
vllm
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.lora.request
import
LoRARequest
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
@
dataclass
...
...
@@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS
:
List
[
ModelWithQuantization
]
#AWQ quantization is currently not supported in ROCm.
if
is_hip
():
if
current_platform
.
is_rocm
():
MODELS
=
[
ModelWithQuantization
(
model_path
=
"TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
,
...
...
tests/models/decoder_only/vision_language/test_paligemma.py
View file @
4e2d95e3
...
...
@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding
)
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
is_hip
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
...utils
import
check_logprobs_close
...
...
@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if
is_hip
():
if
current_platform
.
is_rocm
():
os
.
environ
[
"VLLM_USE_TRITON_FLASH_ATTN"
]
=
"0"
...
...
@@ -151,7 +152,7 @@ def run_test(
pytest
.
param
(
"float"
,
marks
=
pytest
.
mark
.
skipif
(
is_hip
(),
current_platform
.
is_rocm
(),
reason
=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma"
)
),
"half"
...
...
tests/models/decoder_only/vision_language/test_phi3v.py
View file @
4e2d95e3
...
...
@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
is_hip
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
...
...
@@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if
is_hip
():
if
current_platform
.
is_rocm
():
os
.
environ
[
"VLLM_USE_TRITON_FLASH_ATTN"
]
=
"0"
...
...
tests/spec_decode/e2e/test_integration_dist_tp2.py
View file @
4e2d95e3
...
...
@@ -5,7 +5,7 @@ tensor parallelism.
import
pytest
import
torch
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
from
.conftest
import
run_equality_correctness_test_tp
...
...
@@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when tensor parallelism is used.
"""
if
is_hip
():
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"hip is not well-supported yet"
)
run_equality_correctness_test_tp
(
"JackFram/llama-68m"
,
common_llm_kwargs
,
...
...
tests/utils.py
View file @
4e2d95e3
...
...
@@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.utils
import
(
FlexibleArgumentParser
,
GB_bytes
,
cuda_device_count_stateless
,
get_open_port
,
is_hip
)
cuda_device_count_stateless
,
get_open_port
)
if
current_platform
.
is_rocm
():
from
amdsmi
import
(
amdsmi_get_gpu_vram_usage
,
...
...
@@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output
:
Dict
[
int
,
str
]
=
{}
output_raw
:
Dict
[
int
,
float
]
=
{}
for
device
in
devices
:
if
is_hip
():
if
current_platform
.
is_rocm
():
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
mem_info
=
amdsmi_get_gpu_vram_usage
(
dev_handle
)
gb_used
=
mem_info
[
"vram_used"
]
/
2
**
10
...
...
vllm/_custom_ops.py
View file @
4e2d95e3
...
...
@@ -674,8 +674,8 @@ def scaled_fp8_quant(
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
Tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
if
vllm
.
utils
.
is_hip
()
\
else
torch
.
float8_e4m3fn
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
\
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
...
...
vllm/attention/ops/blocksparse_attention/interface.py
View file @
4e2d95e3
...
...
@@ -3,7 +3,6 @@ import math
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
...
...
@@ -32,7 +31,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
):
super
().
__init__
()
if
use_spda
is
None
:
use_spda
=
is_hip
()
or
current_platform
.
is_cpu
()
or
not
\
use_spda
=
current_platform
.
is_rocm
()
or
\
current_platform
.
is_cpu
()
or
not
\
IS_COMPUTE_8_OR_ABOVE
device
=
device
or
(
torch
.
cuda
.
current_device
()
if
current_platform
.
is_cuda_alike
()
else
"cpu"
)
...
...
vllm/attention/selector.py
View file @
4e2d95e3
...
...
@@ -10,7 +10,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
is_hip
from
vllm.utils
import
STR_BACKEND_ENV_VAR
logger
=
init_logger
(
__name__
)
...
...
@@ -208,7 +208,7 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
if
is_hip
():
if
current_platform
.
is_rocm
():
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
...
...
vllm/config.py
View file @
4e2d95e3
...
...
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config
,
get_hf_text_config
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_hip
,
print_warning_once
)
print_warning_once
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -350,7 +350,7 @@ class ModelConfig:
raise
ValueError
(
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"be one of
{
supported_quantization
}
."
)
if
is_hip
(
if
current_platform
.
is_rocm
(
)
and
self
.
quantization
not
in
rocm_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
...
...
@@ -365,7 +365,7 @@ class ModelConfig:
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models."
,
self
.
quantization
)
if
(
self
.
quantization
==
"awq"
and
is_hip
()
if
(
self
.
quantization
==
"awq"
and
current_platform
.
is_rocm
()
and
not
envs
.
VLLM_USE_TRITON_AWQ
):
logger
.
warning
(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
...
...
@@ -843,7 +843,8 @@ class LoadConfig:
self
.
load_format
=
LoadFormat
(
load_format
)
rocm_not_supported_load_format
:
List
[
str
]
=
[]
if
is_hip
()
and
load_format
in
rocm_not_supported_load_format
:
if
current_platform
.
is_rocm
(
)
and
load_format
in
rocm_not_supported_load_format
:
rocm_supported_load_format
=
[
f
for
f
in
LoadFormat
.
__members__
if
(
f
not
in
rocm_not_supported_load_format
)
...
...
@@ -967,7 +968,7 @@ class ParallelConfig:
if
self
.
use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
if
is_hip
():
if
current_platform
.
is_rocm
():
self
.
disable_custom_all_reduce
=
True
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
...
...
vllm/executor/ray_utils.py
View file @
4e2d95e3
...
...
@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
get_ip
,
is_hip
from
vllm.utils
import
get_ip
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
@@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available
()
# Connect to a ray cluster.
if
is_hip
()
or
current_platform
.
is_xpu
():
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
...
...
vllm/model_executor/custom_op.py
View file @
4e2d95e3
...
...
@@ -7,7 +7,7 @@ import vllm.envs as envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -72,7 +72,7 @@ class CustomOp(nn.Module):
if
not
enabled
:
return
self
.
forward_native
if
is_hip
():
if
current_platform
.
is_rocm
():
return
self
.
forward_hip
elif
current_platform
.
is_cpu
():
return
self
.
forward_cpu
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment