Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ffa3f31
Unverified
Commit
6ffa3f31
authored
Sep 18, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 18, 2024
Browse files
[CI/Build] Avoid CUDA initialization (#8534)
parent
e3515729
Changes
55
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
55 additions
and
87 deletions
+55
-87
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+3
-2
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+2
-1
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+5
-9
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+3
-9
tests/lora/test_layers.py
tests/lora/test_layers.py
+2
-3
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+5
-13
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+5
-13
tests/models/decoder_only/language/test_granite.py
tests/models/decoder_only/language/test_granite.py
+2
-7
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+1
-3
tests/quantization/utils.py
tests/quantization/utils.py
+5
-3
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-1
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+2
-3
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+1
-2
vllm/attention/selector.py
vllm/attention/selector.py
+2
-2
vllm/config.py
vllm/config.py
+6
-6
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-1
vllm/envs.py
vllm/envs.py
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+3
-3
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+1
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-3
No files found.
tests/kernels/test_mamba_ssm.py
View file @
6ffa3f31
...
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
...
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
selective_scan_fn
,
selective_state_update
)
from
vllm.utils
import
seed_everything
def
selective_state_update_ref
(
state
,
def
selective_state_update_ref
(
state
,
...
@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
...
@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw
=
max
(
rtolw
,
rtol
)
rtolw
=
max
(
rtolw
,
rtol
)
atolw
=
max
(
atolw
,
atol
)
atolw
=
max
(
atolw
,
atol
)
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
seed_everything
(
0
)
batch_size
=
2
batch_size
=
2
dim
=
4
dim
=
4
dstate
=
8
dstate
=
8
...
@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
...
@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if
torch
.
version
.
hip
:
if
torch
.
version
.
hip
:
atol
*=
2
atol
*=
2
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
seed_everything
(
0
)
batch_size
=
1
batch_size
=
1
state
=
torch
.
randn
(
batch_size
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state
=
torch
.
randn
(
batch_size
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
...
...
tests/kernels/test_moe.py
View file @
6ffa3f31
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize
)
marlin_quantize
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
seed_everything
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
@@ -151,7 +152,7 @@ def test_fused_marlin_moe(
...
@@ -151,7 +152,7 @@ def test_fused_marlin_moe(
act_order
:
bool
,
act_order
:
bool
,
num_bits
:
int
,
num_bits
:
int
,
):
):
torch
.
manual_seed
(
7
)
seed_everything
(
7
)
if
topk
>
e
:
if
topk
>
e
:
return
return
...
...
tests/kernels/test_pos_encoding.py
View file @
6ffa3f31
...
@@ -5,6 +5,7 @@ import pytest
...
@@ -5,6 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
@@ -46,9 +47,8 @@ def test_rotary_embedding(
...
@@ -46,9 +47,8 @@ def test_rotary_embedding(
)
->
None
:
)
->
None
:
if
rotary_dim
is
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rotary_dim
=
head_size
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
seed_everything
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rotary_dim
=
head_size
...
@@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
...
@@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rotary_dim
=
head_size
...
@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rotary_dim
=
head_size
...
...
tests/kernels/test_prefix_prefill.py
View file @
6ffa3f31
...
@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
seed_everything
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -39,10 +39,7 @@ def test_contexted_kv_attention(
...
@@ -39,10 +39,7 @@ def test_contexted_kv_attention(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
random
.
seed
(
0
)
seed_everything
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
random
.
seed
(
0
)
seed_everything
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
...
tests/lora/test_layers.py
View file @
6ffa3f31
...
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
ParallelLMHead
,
VocabParallelEmbedding
,
get_masked_input_and_mask
)
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.utils
import
seed_everything
from
.utils
import
DummyLoRAManager
from
.utils
import
DummyLoRAManager
...
@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len
)
->
None
:
seq_len
)
->
None
:
dtype
=
torch
.
float16
dtype
=
torch
.
float16
seed
=
0
seed
=
0
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
...
tests/lora/test_punica_sizes.py
View file @
6ffa3f31
...
@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
...
@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
is set to [1, 2, 4, 8, 16, 32, 64].
"""
"""
import
random
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
...
@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -145,11 +145,8 @@ def test_punica_sgmv(
...
@@ -145,11 +145,8 @@ def test_punica_sgmv(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -238,11 +235,8 @@ def test_punica_bgmv(
...
@@ -238,11 +235,8 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -329,11 +323,9 @@ def test_punica_expand_nslices(
...
@@ -329,11 +323,9 @@ def test_punica_expand_nslices(
):
):
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
inputs_tensor
,
inputs_tensor
,
...
...
tests/lora/test_punica_variation.py
View file @
6ffa3f31
...
@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
...
@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
maximum ranks.
"""
"""
import
random
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
...
@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.utils
import
seed_everything
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -60,11 +60,8 @@ def test_punica_sgmv(
...
@@ -60,11 +60,8 @@ def test_punica_sgmv(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
128
seq_length
=
128
(
(
...
@@ -153,11 +150,8 @@ def test_punica_bgmv(
...
@@ -153,11 +150,8 @@ def test_punica_bgmv(
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
1
seq_length
=
1
(
(
...
@@ -244,11 +238,9 @@ def test_punica_expand_nslices(
...
@@ -244,11 +238,9 @@ def test_punica_expand_nslices(
):
):
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
random
.
seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
random
.
manual_seed
(
seed
)
seed_everything
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seq_length
=
128
if
op_type
==
"sgmv"
else
1
seq_length
=
128
if
op_type
==
"sgmv"
else
1
(
(
inputs_tensor
,
inputs_tensor
,
...
...
tests/models/decoder_only/language/test_granite.py
View file @
6ffa3f31
...
@@ -2,23 +2,18 @@
...
@@ -2,23 +2,18 @@
Run `pytest tests/models/test_granite.py`.
Run `pytest tests/models/test_granite.py`.
"""
"""
import
importlib.metadata
import
pytest
import
pytest
import
transformers
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
TRANSFORMERS_VERSION
=
tuple
(
map
(
int
,
importlib
.
metadata
.
version
(
"transformers"
).
split
(
"."
)))
MODELS
=
[
MODELS
=
[
"ibm/PowerLM-3b"
,
"ibm/PowerLM-3b"
,
]
]
# GraniteForCausalLM will be in transformers >= 4.45
# GraniteForCausalLM will be in transformers >= 4.45
@
pytest
.
mark
.
skipif
(
TRANSFORMERS_VERSION
<
(
4
,
45
)
,
@
pytest
.
mark
.
skipif
(
transformers
.
__version__
<
"4.
45
"
,
reason
=
"granite model test requires transformers >= 4.45"
)
reason
=
"granite model test requires transformers >= 4.45"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
...
...
tests/quantization/test_fp8.py
View file @
6ffa3f31
...
@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
capability
=
current_platform
.
get_device_capability
()
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
>=
89
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
else
:
...
...
tests/quantization/utils.py
View file @
6ffa3f31
...
@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
...
@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
return
False
return
False
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
assert
capability
is
not
None
return
(
capability
>=
QUANTIZATION_METHODS
[
quant_method
].
get_min_capability
())
min_capability
=
QUANTIZATION_METHODS
[
quant_method
].
get_min_capability
()
return
capability
.
to_int
()
>=
min_capability
vllm/attention/backends/rocm_flash_attn.py
View file @
6ffa3f31
...
@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
...
@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
# either
if
torch
.
cuda
.
get
_device_capability
(
)[
0
]
!=
9
:
if
not
current_platform
.
has
_device_capability
(
90
)
:
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
else
:
else
:
try
:
try
:
...
...
vllm/attention/ops/blocksparse_attention/interface.py
View file @
6ffa3f31
...
@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
...
@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
get_sparse_attn_mask
)
IS_COMPUTE_8_OR_ABOVE
=
(
torch
.
cuda
.
is_available
()
IS_COMPUTE_8_OR_ABOVE
=
current_platform
.
has_device_capability
(
80
)
and
current_platform
.
get_device_capability
()[
0
]
>=
8
)
if
IS_COMPUTE_8_OR_ABOVE
:
if
IS_COMPUTE_8_OR_ABOVE
:
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
...
@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
...
@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
use_spda
=
is_hip
()
or
is_cpu
()
or
not
\
use_spda
=
is_hip
()
or
is_cpu
()
or
not
\
IS_COMPUTE_8_OR_ABOVE
IS_COMPUTE_8_OR_ABOVE
device
=
device
or
(
torch
.
cuda
.
current_device
()
device
=
device
or
(
torch
.
cuda
.
current_device
()
if
torch
.
cuda
.
is_availabl
e
()
else
"cpu"
)
if
current_platform
.
is_cuda_alik
e
()
else
"cpu"
)
device
=
torch
.
device
(
device
)
device
=
torch
.
device
(
device
)
# NOTE: vllm CPU backend support BF16 instead of FP16.
# NOTE: vllm CPU backend support BF16 instead of FP16.
dtype
=
dtype
or
(
torch
.
bfloat16
if
IS_COMPUTE_8_OR_ABOVE
dtype
=
dtype
or
(
torch
.
bfloat16
if
IS_COMPUTE_8_OR_ABOVE
...
...
vllm/attention/ops/prefix_prefill.py
View file @
6ffa3f31
...
@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
None
):
sliding_window
=
None
):
cap
=
current_platform
.
get_device_capability
()
BLOCK
=
128
if
current_platform
.
has_device_capability
(
80
)
else
64
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
NUM_WARPS
=
8
NUM_WARPS
=
8
# need to reduce num. blocks when using fp32
# need to reduce num. blocks when using fp32
...
...
vllm/attention/selector.py
View file @
6ffa3f31
...
@@ -203,7 +203,7 @@ def which_attn_to_use(
...
@@ -203,7 +203,7 @@ def which_attn_to_use(
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
current_platform
.
get
_device_capability
(
)[
0
]
!=
9
:
if
not
current_platform
.
has
_device_capability
(
90
)
:
# not Instinct series GPUs.
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
else
:
...
@@ -212,7 +212,7 @@ def which_attn_to_use(
...
@@ -212,7 +212,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
current_platform
.
get
_device_capability
(
)[
0
]
<
8
:
if
not
current_platform
.
has
_device_capability
(
80
)
:
# Volta and Turing NVIDIA GPUs.
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"Cannot use FlashAttention-2 backend for Volta and Turing "
...
...
vllm/config.py
View file @
6ffa3f31
...
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
...
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
get_hf_text_config
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
print_warning_once
)
print_warning_once
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -1035,20 +1035,20 @@ class DeviceConfig:
...
@@ -1035,20 +1035,20 @@ class DeviceConfig:
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
if
device
==
"auto"
:
if
device
==
"auto"
:
# Automated device type detection
# Automated device type detection
if
is_neuron
():
if
current_platform
.
is_cuda_alike
():
self
.
device_type
=
"cuda"
elif
is_neuron
():
self
.
device_type
=
"neuron"
self
.
device_type
=
"neuron"
elif
is_openvino
():
elif
is_openvino
():
self
.
device_type
=
"openvino"
self
.
device_type
=
"openvino"
elif
current_platform
.
is_tpu
():
elif
current_platform
.
is_tpu
():
self
.
device_type
=
"tpu"
self
.
device_type
=
"tpu"
elif
is_cpu
():
elif
current_platform
.
is_cpu
():
self
.
device_type
=
"cpu"
self
.
device_type
=
"cpu"
elif
is_xpu
():
elif
is_xpu
():
self
.
device_type
=
"xpu"
self
.
device_type
=
"xpu"
else
:
else
:
# We don't call torch.cuda.is_available() here to
raise
RuntimeError
(
"Failed to infer device type"
)
# avoid initializing CUDA before workers are forked
self
.
device_type
=
"cuda"
else
:
else
:
# Device type is assigned explicitly
# Device type is assigned explicitly
self
.
device_type
=
device
self
.
device_type
=
device
...
...
vllm/distributed/parallel_state.py
View file @
6ffa3f31
...
@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
@
dataclass
@
dataclass
...
@@ -191,7 +192,7 @@ class GroupCoordinator:
...
@@ -191,7 +192,7 @@ class GroupCoordinator:
assert
self
.
cpu_group
is
not
None
assert
self
.
cpu_group
is
not
None
assert
self
.
device_group
is
not
None
assert
self
.
device_group
is
not
None
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
else
:
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
...
...
vllm/envs.py
View file @
6ffa3f31
...
@@ -60,6 +60,7 @@ if TYPE_CHECKING:
...
@@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_RPC_GET_DATA_TIMEOUT_MS
:
int
=
5000
VLLM_RPC_GET_DATA_TIMEOUT_MS
:
int
=
5000
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
6ffa3f31
...
@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_check_scheme_supported
(
self
,
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
error
:
bool
=
True
)
->
bool
:
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
_tuple
=
current_platform
.
get_device_capability
()
if
capability
is
not
None
:
if
capability
_tuple
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
_tuple
.
to_int
()
supported
=
capability
>=
min_capability
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
if
error
and
not
supported
:
raise
RuntimeError
(
raise
RuntimeError
(
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
6ffa3f31
...
@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
self
.
use_marlin
=
not
current_platform
.
has_device_capability
(
89
)
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
6ffa3f31
...
@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
self
.
use_marlin
=
capability
<
89
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
# Disable marlin for rocm
if
is_hip
():
if
is_hip
():
self
.
use_marlin
=
False
self
.
use_marlin
=
False
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment