Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a2217966
Commit
a2217966
authored
Nov 20, 2024
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev
parents
93089fb2
1a493a24
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
168 additions
and
130 deletions
+168
-130
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+4
-4
requirements-rocm.txt
requirements-rocm.txt
+1
-1
tests/entrypoints/openai/test_oot_registration.py
tests/entrypoints/openai/test_oot_registration.py
+27
-11
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+130
-110
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+5
-3
No files found.
CMakeLists.txt
View file @
a2217966
...
@@ -41,7 +41,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
...
@@ -41,7 +41,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
set
(
CUDA_SUPPORTED_ARCHS
"7.0;7.5;8.0;8.6;8.9;9.0"
)
set
(
CUDA_SUPPORTED_ARCHS
"7.0;7.5;8.0;8.6;8.9;9.0"
)
# Supported AMD GPU architectures.
# Supported AMD GPU architectures.
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx926;gfx928"
)
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx926;gfx928
;gfx936
"
)
#
#
# Supported/expected torch versions for CUDA/ROCm.
# Supported/expected torch versions for CUDA/ROCm.
...
...
csrc/attention/attention_kernels_opt_tc.cu
View file @
a2217966
...
@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC(
...
@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC(
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#ifdef
__gfx928__
#if
def
ined(__gfx936__) || defined(
__gfx928__
)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
...
@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
...
@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
)
{
#ifdef
__gfx928__
#if
def
ined(__gfx936__) || defined(
__gfx928__
)
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel_TC
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
use_vmac
,
PARTITION_SIZE
>
(
PARTITION_SIZE
>
(
...
@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc(
...
@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc(
const
int64_t
attn_masks_stride
)
{
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)
){
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc(
...
@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc(
const
int64_t
attn_masks_stride
)
{
const
int64_t
attn_masks_stride
)
{
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
const
bool
is_block_sparse
=
(
blocksparse_vert_stride
>
1
);
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
if
(
kv_cache_dtype
!=
"auto"
||
query
.
dtype
()
==
at
::
ScalarType
::
Float
||
is_block_sparse
||
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
get_device_name
()
!=
"gfx928"
){
block_size
!=
16
||
query
.
size
(
2
)
!=
128
||
(
get_device_name
()
!=
"gfx928"
&&
get_device_name
()
!=
"gfx936"
)
){
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
paged_attention_v2_opt
(
out
,
exp_sums
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
...
requirements-rocm.txt
View file @
a2217966
...
@@ -11,7 +11,7 @@ pytest-asyncio
...
@@ -11,7 +11,7 @@ pytest-asyncio
tensorizer>=2.9.0
tensorizer>=2.9.0
setuptools_scm>=8
setuptools_scm>=8
torch == 2.4.
0
torch == 2.4.
1
triton == 3.0.0
triton == 3.0.0
flash_attn == 2.6.1
flash_attn == 2.6.1
lmslim == 0.2.0
lmslim == 0.2.0
\ No newline at end of file
tests/entrypoints/openai/test_oot_registration.py
View file @
a2217966
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
import
vllm.envs
as
envs
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
assert
chatml_jinja_path
.
exists
()
...
@@ -6,6 +7,7 @@ assert chatml_jinja_path.exists()
...
@@ -6,6 +7,7 @@ assert chatml_jinja_path.exists()
def
run_and_test_dummy_opt_api_server
(
model
,
tp
=
1
):
def
run_and_test_dummy_opt_api_server
(
model
,
tp
=
1
):
# the model is registered through the plugin
# the model is registered through the plugin
if
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
server_args
=
[
server_args
=
[
"--gpu-memory-utilization"
,
"--gpu-memory-utilization"
,
"0.10"
,
"0.10"
,
...
@@ -18,6 +20,19 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
...
@@ -18,6 +20,19 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
"-tp"
,
"-tp"
,
f
"
{
tp
}
"
,
f
"
{
tp
}
"
,
]
]
else
:
server_args
=
[
"--gpu-memory-utilization"
,
"0.10"
,
"--dtype"
,
"float16"
,
"--chat-template"
,
str
(
chatml_jinja_path
),
"--load-format"
,
"dummy"
,
"-tp"
,
f
"
{
tp
}
"
,
]
with
RemoteOpenAIServer
(
model
,
server_args
)
as
server
:
with
RemoteOpenAIServer
(
model
,
server_args
)
as
server
:
client
=
server
.
get_client
()
client
=
server
.
get_client
()
completion
=
client
.
chat
.
completions
.
create
(
completion
=
client
.
chat
.
completions
.
create
(
...
@@ -39,4 +54,5 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
...
@@ -39,4 +54,5 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
def
test_oot_registration_for_api_server
(
dummy_opt_path
:
str
):
def
test_oot_registration_for_api_server
(
dummy_opt_path
:
str
):
dummy_opt_path
=
"facebook/opt-125m"
run_and_test_dummy_opt_api_server
(
dummy_opt_path
)
run_and_test_dummy_opt_api_server
(
dummy_opt_path
)
tests/kernels/test_flash_attn.py
View file @
a2217966
...
@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
...
@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
import
torch
import
torch
import
vllm.attention.backends.flash_attn
# noqa: F401
from
vllm.utils
import
is_hip
if
is_hip
():
import
flash_attn
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
...
@@ -70,16 +74,16 @@ def ref_paged_attn(
...
@@ -70,16 +74,16 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
if
not
is_hip
():
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
...
@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv(
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
seed_everything
(
0
)
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_seqs
=
len
(
kv_lens
)
...
@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv(
...
@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv(
num_blocks
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
if
is_hip
():
output
=
flash_attn
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
else
:
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
...
@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv(
...
@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv(
else
:
else
:
test_utils
=
[
"test_faketensor"
]
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
args
=
tuple
(),
kwargs
=
dict
(
kwargs
=
dict
(
...
...
tests/kernels/test_prefix_prefill.py
View file @
a2217966
...
@@ -4,14 +4,16 @@ import time
...
@@ -4,14 +4,16 @@ import time
import
pytest
import
pytest
import
torch
import
torch
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
seed_everything
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
seed_everything
if
not
is_hip
():
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
HEAD_SIZES
=
[
128
,
96
,
24
]
HEAD_SIZES
=
[
128
,
96
,
24
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment