Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
978b45f3
Unverified
Commit
978b45f3
authored
Jan 23, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jan 23, 2025
Browse files
[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
c5b4b11d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
150 additions
and
82 deletions
+150
-82
CMakeLists.txt
CMakeLists.txt
+20
-25
setup.py
setup.py
+8
-4
tests/kernels/test_cascade_flash_attn.py
tests/kernels/test_cascade_flash_attn.py
+14
-10
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+18
-4
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+24
-3
vllm/envs.py
vllm/envs.py
+12
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+33
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+21
-25
No files found.
CMakeLists.txt
View file @
978b45f3
...
@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
...
@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
# Suppress potential warnings about unused manually-specified variables
set
(
ignoreMe
"
${
VLLM_PYTHON_PATH
}
"
)
set
(
ignoreMe
"
${
VLLM_PYTHON_PATH
}
"
)
# Prevent installation of dependencies (cutlass) by default.
install
(
CODE
"set(CMAKE_INSTALL_LOCAL_ONLY TRUE)"
ALL_COMPONENTS
)
#
#
# Supported python versions. These versions will be searched in order, the
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
# first match will be selected. These should be kept in sync with setup.py.
...
@@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
...
@@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif
()
endif
()
# vllm-flash-attn currently only supported on CUDA
# vllm-flash-attn currently only supported on CUDA
if
(
NOT VLLM_
TARGET_DEVICE
STREQUAL
"
cuda
"
)
if
(
NOT VLLM_
GPU_LANG
STREQUAL
"
CUDA
"
)
return
()
return
()
endif
()
endif
()
...
@@ -558,7 +555,7 @@ endif()
...
@@ -558,7 +555,7 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
# They should be identical but if they aren't, this is a massive footgun.
#
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_f
lash_attn_c
.
# To only install vllm-flash-attn, use --component
_
vllm_f
a2_C (for FA2) or --component _vllm_fa3_C (for FA3)
.
# If no component is specified, vllm-flash-attn is still installed.
# If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
...
@@ -570,42 +567,40 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
...
@@ -570,42 +567,40 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif
()
endif
()
if
(
VLLM_FLASH_ATTN_SRC_DIR
)
if
(
VLLM_FLASH_ATTN_SRC_DIR
)
FetchContent_Declare
(
vllm-flash-attn SOURCE_DIR
${
VLLM_FLASH_ATTN_SRC_DIR
}
)
FetchContent_Declare
(
vllm-flash-attn SOURCE_DIR
${
VLLM_FLASH_ATTN_SRC_DIR
}
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
)
else
()
else
()
FetchContent_Declare
(
FetchContent_Declare
(
vllm-flash-attn
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 9
6266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_TAG 9
0eacc1af2a7c3de62ea249e929ed5faccf38954
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
)
)
endif
()
endif
()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set
(
VLLM_PARENT_BUILD ON
)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install
(
CODE
"file(MAKE_DIRECTORY
\"\$
{CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn
\"
)"
COMPONENT vllm_flash_attn_c
)
# Make sure vllm-flash-attn install rules are nested under vllm/
install
(
CODE
"set(CMAKE_INSTALL_LOCAL_ONLY FALSE)"
COMPONENT vllm_flash_attn_c
)
install
(
CODE
"set(OLD_CMAKE_INSTALL_PREFIX
\"\$
{CMAKE_INSTALL_PREFIX}
\"
)"
COMPONENT vllm_flash_attn_c
)
install
(
CODE
"set(CMAKE_INSTALL_PREFIX
\"\$
{CMAKE_INSTALL_PREFIX}/vllm/
\"
)"
COMPONENT vllm_flash_attn_c
)
# Fetch the vllm-flash-attn library
# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable
(
vllm-flash-attn
)
FetchContent_MakeAvailable
(
vllm-flash-attn
)
message
(
STATUS
"vllm-flash-attn is available at
${
vllm-flash-attn_SOURCE_DIR
}
"
)
message
(
STATUS
"vllm-flash-attn is available at
${
vllm-flash-attn_SOURCE_DIR
}
"
)
# Restore the install prefix
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
install
(
CODE
"set(CMAKE_INSTALL_PREFIX
\"\$
{OLD_CMAKE_INSTALL_PREFIX}
\"
)"
COMPONENT vllm_flash_attn_c
)
# case only one is built, in the case both are built redundant work is done)
install
(
CODE
"set(CMAKE_INSTALL_LOCAL_ONLY TRUE)"
COMPONENT vllm_flash_attn_c
)
install
(
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN
"*.py"
)
# Copy over the vllm-flash-attn python files
install
(
install
(
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DIRECTORY
${
vllm-flash-attn_SOURCE_DIR
}
/vllm_flash_attn/
DESTINATION
vllm/
vllm_flash_attn
DESTINATION vllm_flash_attn
COMPONENT vllm_f
lash_attn_c
COMPONENT
_
vllm_f
a3_C
FILES_MATCHING PATTERN
"*.py"
FILES_MATCHING PATTERN
"*.py"
)
)
...
...
setup.py
View file @
978b45f3
...
@@ -228,8 +228,11 @@ class cmake_build_ext(build_ext):
...
@@ -228,8 +228,11 @@ class cmake_build_ext(build_ext):
# CMake appends the extension prefix to the install path,
# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
# and outdir already contains that prefix, so we need to remove it.
# We assume only the final component of extension prefix is added by
# CMake, this is currently true for current extensions but may not
# always be the case.
prefix
=
outdir
prefix
=
outdir
for
i
in
range
(
ext
.
name
.
count
(
'.'
))
:
if
'.'
in
ext
.
name
:
prefix
=
prefix
.
parent
prefix
=
prefix
.
parent
# prefix here should actually be the same for all components
# prefix here should actually be the same for all components
...
@@ -298,7 +301,8 @@ class repackage_wheel(build_ext):
...
@@ -298,7 +301,8 @@ class repackage_wheel(build_ext):
files_to_copy
=
[
files_to_copy
=
[
"vllm/_C.abi3.so"
,
"vllm/_C.abi3.so"
,
"vllm/_moe_C.abi3.so"
,
"vllm/_moe_C.abi3.so"
,
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so"
,
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so"
,
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so"
,
"vllm/vllm_flash_attn/flash_attn_interface.py"
,
"vllm/vllm_flash_attn/flash_attn_interface.py"
,
"vllm/vllm_flash_attn/__init__.py"
,
"vllm/vllm_flash_attn/__init__.py"
,
"vllm/cumem_allocator.abi3.so"
,
"vllm/cumem_allocator.abi3.so"
,
...
@@ -593,8 +597,8 @@ if _is_hip():
...
@@ -593,8 +597,8 @@ if _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
if
_is_cuda
():
if
_is_cuda
():
ext_modules
.
append
(
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
CMakeExtension
(
name
=
"vllm.vllm_flash_attn.vllm_f
lash_attn_c
"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn.
_
vllm_f
a3_C
"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.cumem_allocator"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.cumem_allocator"
))
if
_build_custom_ops
():
if
_build_custom_ops
():
...
...
tests/kernels/test_cascade_flash_attn.py
View file @
978b45f3
...
@@ -78,6 +78,7 @@ CASES = [
...
@@ -78,6 +78,7 @@ CASES = [
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
50
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
50
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_cascade
(
def
test_cascade
(
seq_lens_and_common_prefix
:
Tuple
[
List
[
Tuple
[
int
,
int
]],
int
],
seq_lens_and_common_prefix
:
Tuple
[
List
[
Tuple
[
int
,
int
]],
int
],
...
@@ -87,8 +88,14 @@ def test_cascade(
...
@@ -87,8 +88,14 @@ def test_cascade(
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
fa_version
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
if
fa_version
==
3
and
(
torch
.
cuda
.
get_device_capability
()
==
(
8
,
6
)
or
torch
.
cuda
.
get_device_capability
()
==
(
8
,
9
)):
pytest
.
skip
(
"Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
window_size
=
(
-
1
,
-
1
)
window_size
=
(
-
1
,
-
1
)
...
@@ -118,9 +125,7 @@ def test_cascade(
...
@@ -118,9 +125,7 @@ def test_cascade(
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
num_blocks
,
num_blocks
,
...
@@ -140,7 +145,7 @@ def test_cascade(
...
@@ -140,7 +145,7 @@ def test_cascade(
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
seqused_k
=
kv_lens_tensor
,
max_seqlen_q
=
max_query_len
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
softmax_scale
=
scale
,
...
@@ -154,10 +159,8 @@ def test_cascade(
...
@@ -154,10 +159,8 @@ def test_cascade(
assert
all
(
common_prefix_len
<
kv_len
for
kv_len
in
kv_lens
)
assert
all
(
common_prefix_len
<
kv_len
for
kv_len
in
kv_lens
)
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
total_num_query_tokens
],
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
total_num_query_tokens
],
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
cu_prefix_kv_lens
=
torch
.
tensor
([
0
,
common_prefix_len
],
dtype
=
torch
.
int32
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
)
cu_suffix_kv_lens
=
(
suffix_kv_lens
=
kv_lens_tensor
-
common_prefix_len
cu_kv_lens
-
torch
.
arange
(
num_seqs
+
1
,
dtype
=
torch
.
int32
)
*
common_prefix_len
)
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
cascade_attention
(
cascade_attention
(
output
=
output
,
output
=
output
,
...
@@ -167,8 +170,8 @@ def test_cascade(
...
@@ -167,8 +170,8 @@ def test_cascade(
cu_query_lens
=
cu_query_lens
,
cu_query_lens
=
cu_query_lens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_
prefix_kv_lens
=
cu_
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
cu_
suffix_kv_lens
=
cu_
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
max_kv_len
=
max_kv_len
,
max_kv_len
=
max_kv_len
,
softmax_scale
=
scale
,
softmax_scale
=
scale
,
alibi_slopes
=
None
,
alibi_slopes
=
None
,
...
@@ -176,6 +179,7 @@ def test_cascade(
...
@@ -176,6 +179,7 @@ def test_cascade(
logits_soft_cap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
logits_soft_cap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
block_table
=
block_tables
,
block_table
=
block_tables
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
fa_version
=
fa_version
,
)
)
# Compare the results.
# Compare the results.
...
...
tests/kernels/test_flash_attn.py
View file @
978b45f3
...
@@ -80,6 +80,7 @@ def ref_paged_attn(
...
@@ -80,6 +80,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
use_out
:
bool
,
...
@@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
...
@@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
fa_version
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
if
fa_version
==
3
and
(
torch
.
cuda
.
get_device_capability
()
==
(
8
,
6
)
or
torch
.
cuda
.
get_device_capability
()
==
(
8
,
9
)):
pytest
.
skip
(
"Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_query_heads
=
num_heads
[
0
]
...
@@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
cache_seqlens
=
kv_lens_tensor
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
window_size
=
window_size
,
fa_version
=
fa_version
,
)
)
output
=
output
if
not
use_out
else
out
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
output
=
output
.
squeeze
(
1
)
...
@@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
use_out
:
bool
,
...
@@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
...
@@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
fa_version
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
if
fa_version
==
3
and
(
torch
.
cuda
.
get_device_capability
()
==
(
8
,
6
)
or
torch
.
cuda
.
get_device_capability
()
==
(
8
,
9
)):
pytest
.
skip
(
"Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
...
@@ -198,9 +213,7 @@ def test_varlen_with_paged_kv(
...
@@ -198,9 +213,7 @@ def test_varlen_with_paged_kv(
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
kv_lens
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
...
@@ -215,7 +228,7 @@ def test_varlen_with_paged_kv(
...
@@ -215,7 +228,7 @@ def test_varlen_with_paged_kv(
v
=
value_cache
,
v
=
value_cache
,
out
=
out
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_
kv_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
softmax_scale
=
scale
,
...
@@ -223,6 +236,7 @@ def test_varlen_with_paged_kv(
...
@@ -223,6 +236,7 @@ def test_varlen_with_paged_kv(
window_size
=
window_size
,
window_size
=
window_size
,
block_table
=
block_tables
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
)
)
output
=
output
if
not
use_out
else
out
output
=
output
if
not
use_out
else
out
...
...
vllm/attention/backends/flash_attn.py
View file @
978b45f3
...
@@ -17,7 +17,9 @@ from vllm.attention.backends.utils import (
...
@@ -17,7 +17,9 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.envs
import
VLLM_FLASH_ATTN_VERSION
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -25,7 +27,8 @@ if TYPE_CHECKING:
...
@@ -25,7 +27,8 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
flash_attn_with_kvcache
,
is_fa_version_supported
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -634,6 +637,20 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -634,6 +637,20 @@ class FlashAttentionImpl(AttentionImpl):
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if
current_platform
.
get_device_capability
()[
0
]
>=
9
:
self
.
fa_version
=
3
if
is_fa_version_supported
(
3
)
else
2
else
:
self
.
fa_version
=
2
if
VLLM_FLASH_ATTN_VERSION
is
not
None
:
assert
VLLM_FLASH_ATTN_VERSION
in
[
2
,
3
]
self
.
fa_version
=
VLLM_FLASH_ATTN_VERSION
assert
is_fa_version_supported
(
self
.
fa_version
)
def
forward
(
def
forward
(
self
,
self
,
layer
:
AttentionLayer
,
layer
:
AttentionLayer
,
...
@@ -752,6 +769,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -752,6 +769,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
out
=
prefill_output
,
fa_version
=
self
.
fa_version
,
)
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
...
@@ -765,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -765,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens
_k
=
prefill_meta
.
seq_
start_loc
,
seqused
_k
=
prefill_meta
.
seq_
lens_tensor
,
max_seqlen_k
=
max_seq_len
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
...
@@ -774,6 +792,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -774,6 +792,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
out
=
prefill_output
,
fa_version
=
self
.
fa_version
,
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
@@ -793,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -793,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
cu_seqlens
_k
=
decode_meta
.
seq_
start_loc
,
seqused
_k
=
decode_meta
.
seq_
lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
...
@@ -802,6 +821,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -802,6 +821,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
out
=
decode_output
,
fa_version
=
self
.
fa_version
,
)
)
else
:
else
:
# Use flash_attn_with_kvcache for normal decoding.
# Use flash_attn_with_kvcache for normal decoding.
...
@@ -822,6 +842,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -822,6 +842,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
fa_version
,
)
)
return
output
return
output
...
...
vllm/envs.py
View file @
978b45f3
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
@@ -90,6 +91,12 @@ def get_default_config_root():
...
@@ -90,6 +91,12 @@ def get_default_config_root():
)
)
def
maybe_convert_int
(
value
:
Optional
[
str
])
->
Optional
[
int
]:
if
value
is
None
:
return
None
return
int
(
value
)
# The begin-* and end* here are used by the documentation generator
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# to extract the used env vars.
...
@@ -203,6 +210,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -203,6 +210,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION"
:
lambda
:
maybe_convert_int
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_VERSION"
,
None
)),
# Internal flag to enable Dynamo fullgraph capture
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
lambda
:
bool
(
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
978b45f3
...
@@ -9,8 +9,11 @@ import triton.language as tl
...
@@ -9,8 +9,11 @@ import triton.language as tl
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.envs
import
VLLM_FLASH_ATTN_VERSION
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
is_fa_version_supported
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -63,7 +66,7 @@ class FlashAttentionMetadata:
...
@@ -63,7 +66,7 @@ class FlashAttentionMetadata:
max_query_len
:
int
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
max_seq_len
:
int
seq_
start_loc
:
torch
.
Tensor
seq_
lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
@@ -71,8 +74,8 @@ class FlashAttentionMetadata:
...
@@ -71,8 +74,8 @@ class FlashAttentionMetadata:
use_cascade
:
bool
use_cascade
:
bool
common_prefix_len
:
int
common_prefix_len
:
int
cu_prefix_query_lens
:
Optional
[
torch
.
Tensor
]
cu_prefix_query_lens
:
Optional
[
torch
.
Tensor
]
cu_
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
cu_
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
# For logging.
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
...
@@ -128,6 +131,20 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -128,6 +131,20 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"are not implemented for "
"FlashAttentionImpl"
)
"FlashAttentionImpl"
)
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if
current_platform
.
get_device_capability
()[
0
]
>=
9
:
self
.
fa_version
=
3
if
is_fa_version_supported
(
3
)
else
2
else
:
self
.
fa_version
=
2
if
VLLM_FLASH_ATTN_VERSION
is
not
None
:
assert
VLLM_FLASH_ATTN_VERSION
in
[
2
,
3
]
self
.
fa_version
=
VLLM_FLASH_ATTN_VERSION
assert
is_fa_version_supported
(
self
.
fa_version
)
def
forward
(
def
forward
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -196,7 +213,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -196,7 +213,7 @@ class FlashAttentionImpl(AttentionImpl):
out
=
output
[:
num_actual_tokens
],
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_len
,
max_seqlen_q
=
attn_metadata
.
max_query_len
,
cu_seqlens
_k
=
attn_metadata
.
seq_
start_loc
,
seqused
_k
=
attn_metadata
.
seq_
lens
,
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
...
@@ -204,6 +221,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -204,6 +221,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
fa_version
,
)
)
return
output
return
output
...
@@ -216,8 +234,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -216,8 +234,8 @@ class FlashAttentionImpl(AttentionImpl):
cu_query_lens
=
attn_metadata
.
query_start_loc
,
cu_query_lens
=
attn_metadata
.
query_start_loc
,
max_query_len
=
attn_metadata
.
max_query_len
,
max_query_len
=
attn_metadata
.
max_query_len
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
cu_prefix_query_lens
=
attn_metadata
.
cu_prefix_query_lens
,
cu_
prefix_kv_lens
=
attn_metadata
.
cu_
prefix_kv_lens
,
prefix_kv_lens
=
attn_metadata
.
prefix_kv_lens
,
cu_
suffix_kv_lens
=
attn_metadata
.
cu_
suffix_kv_lens
,
suffix_kv_lens
=
attn_metadata
.
suffix_kv_lens
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
...
@@ -225,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -225,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
fa_version
,
)
)
return
output
return
output
...
@@ -305,8 +324,8 @@ def cascade_attention(
...
@@ -305,8 +324,8 @@ def cascade_attention(
cu_query_lens
:
torch
.
Tensor
,
cu_query_lens
:
torch
.
Tensor
,
max_query_len
:
int
,
max_query_len
:
int
,
cu_prefix_query_lens
:
torch
.
Tensor
,
cu_prefix_query_lens
:
torch
.
Tensor
,
cu_
prefix_kv_lens
:
torch
.
Tensor
,
prefix_kv_lens
:
torch
.
Tensor
,
cu_
suffix_kv_lens
:
torch
.
Tensor
,
suffix_kv_lens
:
torch
.
Tensor
,
max_kv_len
:
int
,
max_kv_len
:
int
,
softmax_scale
:
float
,
softmax_scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
...
@@ -314,6 +333,7 @@ def cascade_attention(
...
@@ -314,6 +333,7 @@ def cascade_attention(
logits_soft_cap
:
float
,
logits_soft_cap
:
float
,
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
fa_version
:
int
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
alibi_slopes
is
None
,
(
"Cascade attention does not support ALiBi."
)
assert
alibi_slopes
is
None
,
(
"Cascade attention does not support ALiBi."
)
# TODO: Support sliding window.
# TODO: Support sliding window.
...
@@ -332,7 +352,7 @@ def cascade_attention(
...
@@ -332,7 +352,7 @@ def cascade_attention(
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_prefix_query_lens
,
cu_seqlens_q
=
cu_prefix_query_lens
,
cu_seqlens_k
=
cu_
prefix_kv_lens
,
seqused_k
=
prefix_kv_lens
,
max_seqlen_q
=
num_tokens
,
max_seqlen_q
=
num_tokens
,
max_seqlen_k
=
common_prefix_len
,
max_seqlen_k
=
common_prefix_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
...
@@ -341,6 +361,7 @@ def cascade_attention(
...
@@ -341,6 +361,7 @@ def cascade_attention(
block_table
=
block_table
[:
1
],
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
fa_version
=
fa_version
,
)
)
# Process suffix per query.
# Process suffix per query.
...
@@ -349,7 +370,7 @@ def cascade_attention(
...
@@ -349,7 +370,7 @@ def cascade_attention(
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_
suffix_kv_lens
,
seqused_k
=
suffix_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
-
common_prefix_len
,
max_seqlen_k
=
max_kv_len
-
common_prefix_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
...
@@ -358,6 +379,7 @@ def cascade_attention(
...
@@ -358,6 +379,7 @@ def cascade_attention(
block_table
=
block_table
[:,
num_common_kv_blocks
:],
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
fa_version
=
fa_version
,
)
)
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
978b45f3
...
@@ -199,11 +199,11 @@ class GPUModelRunner:
...
@@ -199,11 +199,11 @@ class GPUModelRunner:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
query_start_loc_np
=
self
.
query_start_loc_cpu
.
numpy
()
self
.
query_start_loc_np
=
self
.
query_start_loc_cpu
.
numpy
()
self
.
seq_
start_loc
_cpu
=
torch
.
zeros
(
self
.
max_num_reqs
+
1
,
self
.
seq_
lens
_cpu
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
seq_
start_loc
_np
=
self
.
seq_
start_loc
_cpu
.
numpy
()
self
.
seq_
lens
_np
=
self
.
seq_
lens
_cpu
.
numpy
()
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
# Remove stopped requests from the cached states.
# Remove stopped requests from the cached states.
...
@@ -412,11 +412,10 @@ class GPUModelRunner:
...
@@ -412,11 +412,10 @@ class GPUModelRunner:
np
.
cumsum
(
num_scheduled_tokens
,
np
.
cumsum
(
num_scheduled_tokens
,
out
=
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
])
out
=
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
])
seq_lens
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
num_scheduled_tokens
)
max_seq_len
=
seq_lens
.
max
()
max_seq_len
=
self
.
seq_lens_np
[:
num_reqs
].
max
()
self
.
seq_start_loc_np
[
0
]
=
0
np
.
cumsum
(
seq_lens
,
out
=
self
.
seq_start_loc_np
[
1
:
num_reqs
+
1
])
# Copy the tensors to the GPU.
# Copy the tensors to the GPU.
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
...
@@ -433,8 +432,8 @@ class GPUModelRunner:
...
@@ -433,8 +432,8 @@ class GPUModelRunner:
non_blocking
=
True
)
non_blocking
=
True
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
query_start_loc
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
seq_
start_loc
=
self
.
seq_
start_loc
_cpu
[:
num_reqs
+
1
].
to
(
seq_
lens
=
self
.
seq_
lens
_cpu
[:
num_reqs
].
to
(
self
.
device
,
self
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
slot_mapping
=
self
.
slot_mapping_cpu
[:
total_num_scheduled_tokens
].
to
(
slot_mapping
=
self
.
slot_mapping_cpu
[:
total_num_scheduled_tokens
].
to
(
self
.
device
,
non_blocking
=
True
).
long
()
self
.
device
,
non_blocking
=
True
).
long
()
...
@@ -506,33 +505,30 @@ class GPUModelRunner:
...
@@ -506,33 +505,30 @@ class GPUModelRunner:
[
0
,
total_num_scheduled_tokens
],
[
0
,
total_num_scheduled_tokens
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
cu_
prefix_kv_lens
=
torch
.
tensor
([
0
,
common_prefix_len
],
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
cu_suffix_kv_lens
=
(
suffix_kv_lens
=
(
self
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
self
.
seq_start_loc_np
[:
num_reqs
+
1
]
-
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
device
)
self
.
arange_np
[:
num_reqs
+
1
]
*
common_prefix_len
)
cu_suffix_kv_lens
=
torch
.
from_numpy
(
cu_suffix_kv_lens
).
to
(
self
.
device
)
else
:
else
:
cu_prefix_query_lens
=
None
cu_prefix_query_lens
=
None
cu_
prefix_kv_lens
=
None
prefix_kv_lens
=
None
cu_
suffix_kv_lens
=
None
suffix_kv_lens
=
None
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
seq_
start_loc
=
seq_start_loc
,
seq_
lens
=
seq_lens
,
block_table
=
(
block_table
=
(
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_
prefix_kv_lens
=
cu_
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
cu_
suffix_kv_lens
=
cu_
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# request in the batch. While we should not sample any token from this
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment