Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ddb5447
Commit
4ddb5447
authored
Jun 27, 2025
by
zhuwenwen
Browse files
update cutlass fa and pa
parent
fdda4d82
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
237 additions
and
55 deletions
+237
-55
csrc/cache.h
csrc/cache.h
+6
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+110
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+9
-0
examples/offline_inference/basic/basic.py
examples/offline_inference/basic/basic.py
+2
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+15
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+49
-38
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+40
-16
vllm/envs.py
vllm/envs.py
+6
-0
No files found.
csrc/cache.h
View file @
4ddb5447
...
...
@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
);
// Copies the per-token `key` / `value` vectors into the paged `key_cache` /
// `value_cache` at the slots listed in `slot_mapping`.
//
// Same parameter list as reshape_and_cache() declared above. The
// implementation is not in this header; see the definition for the exact
// cache layouts it expects.
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
                            torch::Tensor& key_cache,
                            torch::Tensor& value_cache,
                            torch::Tensor& slot_mapping,
                            const std::string& kv_cache_dtype,
                            torch::Tensor& k_scale, torch::Tensor& v_scale);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
...
...
csrc/cache_kernels.cu
View file @
4ddb5447
...
...
@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel(
}
}
// Scatters one token's key and value vectors into the paged KV cache.
//
// Each thread block handles the token whose index is blockIdx.x; its threads
// step over the num_heads * head_size elements of that token in increments of
// blockDim.x. Tokens whose slot is negative are skipped entirely.
//
// Destination indexing (flat offsets into the cache pointers):
//   key_cache   : [num_blocks, num_heads, block_size, head_size]
//   value_cache : [num_blocks, num_heads, head_size, block_size]
//
// `x` is accepted for signature parity with the other reshape kernels but is
// not read by this kernel. `k_scale` / `v_scale` are only dereferenced when
// kv_dt is not kAuto.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,
    // [num_blocks, num_heads, block_size, head_size] target layout
    cache_t* __restrict__ value_cache,
    // [num_blocks, num_heads, head_size, block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, int x, const float* k_scale,
    const float* v_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // A negative slot marks a padding token: nothing is written for it.
  if (slot < 0) {
    return;
  }

  // Split the flat slot into (cache block, position inside that block).
  const int64_t page = slot / block_size;
  const int64_t pos_in_page = slot % block_size;

  // Per-token quantities that do not change across the element loop.
  const int64_t key_row = token * key_stride;
  const int64_t value_row = token * value_stride;
  const int64_t key_page_base = page * num_heads * block_size * head_size;
  const int64_t value_page_base = page * num_heads * head_size * block_size;
  const int elems_per_token = num_heads * head_size;

  for (int e = threadIdx.x; e < elems_per_token; e += blockDim.x) {
    const int head = e / head_size;
    const int dim = e % head_size;

    // K: [num_blocks, num_heads, block_size, head_size]
    const int64_t key_dst = key_page_base + head * block_size * head_size +
                            pos_in_page * head_size + dim;
    // V: [num_blocks, num_heads, head_size, block_size]
    const int64_t value_dst = value_page_base + head * head_size * block_size +
                              dim * block_size + pos_in_page;

    scalar_t k_in = key[key_row + e];
    scalar_t v_in = value[value_row + e];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      // Cache stores the input type as-is: plain copy.
      key_cache[key_dst] = k_in;
      value_cache[value_dst] = v_in;
    } else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
      key_cache[key_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(k_in, *k_scale);
      value_cache[value_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(v_in, *v_scale);
    } else {
      key_cache[key_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_in, *k_scale);
      value_cache[value_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v_in, *v_scale);
    }
  }
}
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
reshape_and_cache_flash_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
...
...
@@ -538,6 +598,56 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE
)
}
// Launches vllm::reshape_and_cache_kernel_cuda for one
// (KV_T, CACHE_T, KV_DTYPE) combination.
//
// The expansion is not self-contained: it reads `grid`, `block`, `stream`,
// `key`, `value`, `key_cache`, `value_cache`, `slot_mapping`, `key_stride`,
// `value_stride`, `num_heads`, `head_size`, `block_size`, `k_scale` and
// `v_scale` from the scope in which it is expanded, so those names must exist
// there. The literal `1` is passed in the argument slot that follows
// `block_size`. The scale tensors are handed over as `const float*`.
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, 1, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
// Host-side launcher: writes the new `key` / `value` vectors into the paged
// KV cache at the slots given by `slot_mapping`.
//
// Expected layouts:
//   key, value   : [num_tokens, num_heads, head_size]
//   key_cache    : [num_blocks, num_heads, block_size, head_size]
//   value_cache  : [num_blocks, num_heads, head_size, block_size]
//   slot_mapping : [num_tokens]; `key` / `value` may have more rows than
//                  `slot_mapping`, never fewer.
//
// `kv_cache_dtype` selects the kernel instantiation through
// DISPATCH_BY_KV_CACHE_DTYPE; `k_scale` / `v_scale` are forwarded to the kernel
// as `const float*`.
//
// Raises (via TORCH_CHECK) when the tensor ranks or shapes are inconsistent.
// The kernel derives every cache offset from key's num_heads / head_size, so a
// mismatch would otherwise be an out-of-bounds device access rather than an
// error.
void reshape_and_cache_cuda(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
              "key/value must be [num_tokens, num_heads, head_size]");
  TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
              "cache tensor shape mismatch");
  TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
                  key_cache.size(1) == value_cache.size(1) &&
                  key_cache.size(2) == value_cache.size(3) &&
                  key_cache.size(3) == value_cache.size(2),
              "key/value cache dimension mismatch");

  // The names below are consumed by CALL_RESHAPE_AND_CACHE_CUDA; keep them.
  const int num_tokens = static_cast<int>(slot_mapping.size(0));
  const int num_heads = static_cast<int>(key.size(1));
  const int head_size = static_cast<int>(key.size(2));
  // k layout: [num_blocks, num_heads, block_size, head_size]
  const int block_size = static_cast<int>(key_cache.size(2));

  // The kernel reads num_heads * head_size elements per token from both key
  // and value, and addresses the caches with the same two sizes.
  TORCH_CHECK(value.size(1) == num_heads && value.size(2) == head_size,
              "value must have the same [num_heads, head_size] as key");
  TORCH_CHECK(key_cache.size(1) == num_heads && key_cache.size(3) == head_size,
              "key_cache must be [num_blocks, num_heads, block_size, "
              "head_size] with num_heads/head_size matching key");
  // One thread block is launched per slot_mapping entry, each reading the
  // key/value row with the same index.
  TORCH_CHECK(key.size(0) >= num_tokens && value.size(0) >= num_tokens,
              "slot_mapping has more entries than key/value have tokens");

  // A zero-sized grid or block is an invalid launch configuration; there is
  // nothing to copy in that case anyway.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  const int key_stride = static_cast<int>(key.stride(0));
  const int value_stride = static_cast<int>(value.stride(0));

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_CUDA);
}
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
...
...
csrc/torch_bindings.cpp
View file @
4ddb5447
...
...
@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor k_scale, Tensor v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCUDA
,
&
reshape_and_cache
);
// Reshape the key(new) and value tensors and cache them.
cache_ops
.
def
(
"reshape_and_cache_cuda(Tensor key, Tensor value, "
"Tensor! key_cache, Tensor! value_cache, Tensor slot_mapping, "
"str kv_cache_dtype, Tensor k_scale, Tensor v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache_cuda"
,
torch
::
kCUDA
,
&
reshape_and_cache_cuda
);
// Reshape the key and value tensors and cache them.
cache_ops
.
def
(
"reshape_and_cache_flash(Tensor key, Tensor value,"
...
...
examples/offline_inference/basic/basic.py
View file @
4ddb5447
...
...
@@ -9,6 +9,7 @@ prompts = [
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"Hello, my name is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
16
)
...
...
@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def
main
():
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
1
,
dtype
=
"float16"
,
trust_remote_code
=
True
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
1
,
dtype
=
"float16"
,
trust_remote_code
=
True
,
enforce_eager
=
True
,
block_size
=
16
,
enable_prefix_caching
=
False
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
...
...
vllm/_custom_ops.py
View file @
4ddb5447
...
...
@@ -2070,6 +2070,21 @@ def reshape_and_cache(
kv_cache_dtype
,
k_scale
,
v_scale
)
def reshape_and_cache_cuda(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Forward to the ``_C_cache_ops.reshape_and_cache_cuda`` custom op.

    All arguments are passed through positionally and unchanged. The op
    returns nothing; ``key_cache`` and ``value_cache`` are the tensors it
    is registered to mutate.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache_cuda
    cache_op(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
def
reshape_and_cache_flash
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
4ddb5447
...
...
@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
if
SUPPORT_TC
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
# , vllm_flash_attn_with_kvcache
# noqa: F401
self
.
fa_attn_func
=
flash_attn_varlen_func
if
not
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
and
gpuname
.
startswith
(
'K100_AI'
):
from
flash_attn
import
vllm_flash_attn_varlen_func
self
.
fa_prefix_attn_func
=
vllm_flash_attn_varlen_func
self
.
fa_prefix_attn_func
=
vllm_flash_attn_varlen_func
# self.fa_decode_attn_func = vllm_flash_attn_with_kvcache
logger
.
debug
(
"Using CUTLASS FA in ROCmBackend"
)
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
else
:
self
.
use_naive_attn
=
True
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
=
True
if
self
.
use_naive_attn
:
if
logits_soft_cap
is
not
None
:
...
...
@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
# prefix-enabled attention -
# not applicable for encoder-only models
if
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
or
(
not
gpuname
.
startswith
(
'K100_AI'
))
:
if
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
version_key
=
triton_key
()
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
...
...
@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i
'''
num_blocks
,
num_kv_heads
,
head_size_div_x
,
block_size
,
x
=
key_cache
.
shape
head_size
=
head_size_div_x
*
x
key_cache_flash
=
key_cache
.
permute
(
0
,
3
,
1
,
2
,
4
)
# [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash
=
key_cache_flash
.
reshape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
# value_cache
value_cache_flash
=
value_cache
.
permute
(
0
,
3
,
1
,
2
)
# [num_blocks, block_size, num_kv_heads, head_size]
output
[:
num_prefill_tokens
]
=
self
.
fa_prefix_attn_func
(
# noqa
q
=
query
,
k
=
key_cache
_flash
,
v
=
value_cache
_flash
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
...
...
@@ -977,28 +968,48 @@ class ROCmFlashAttentionImpl(AttentionImpl):
)
else
:
tree_attention_masks_tensor
=
decode_meta
.
tree_attention_masks_tensor
output
[
num_prefill_tokens
:]
=
paged_attn
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
decode_meta
.
block_tables
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
cross_block_tables
,
decode_meta
.
seq_lens_tensor
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
encoder_seq_lens_tensor
,
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
max_encoder_seq_len
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
attn_masks
=
tree_attention_masks_tensor
,
attn_masks_stride
=
tree_attention_masks_tensor
.
stride
(
0
)
if
tree_attention_masks_tensor
is
not
None
else
0
)
if
envs
.
VLLM_USE_FLASH_ATTN_BACKEND
:
from
flash_attn
import
vllm_flash_attn_with_kvcache
# output[num_prefill_tokens:] = self.fa_decode_attn_func(
output
[
num_prefill_tokens
:]
=
vllm_flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
block_table
=
decode_meta
.
block_tables
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
softcap
=
self
.
logits_soft_cap
,
alibi_slopes
=
self
.
alibi_slopes
,
return_softmax_lse
=
False
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
).
squeeze
(
1
)
else
:
output
[
num_prefill_tokens
:]
=
paged_attn
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
decode_meta
.
block_tables
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
cross_block_tables
,
decode_meta
.
seq_lens_tensor
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
encoder_seq_lens_tensor
,
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
max_encoder_seq_len
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
attn_masks
=
tree_attention_masks_tensor
,
attn_masks_stride
=
tree_attention_masks_tensor
.
stride
(
0
)
if
tree_attention_masks_tensor
is
not
None
else
0
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
...
...
vllm/attention/ops/paged_attn.py
View file @
4ddb5447
...
...
@@ -58,12 +58,23 @@ class PagedAttention:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
16
//
kv_cache
.
element_size
()
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
'''
CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size]
Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
'''
if
envs
.
VLLM_USE_FLASH_ATTN_BACKEND
:
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
else
:
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
@
staticmethod
...
...
@@ -77,16 +88,29 @@ class PagedAttention:
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
if
envs
.
VLLM_USE_FLASH_ATTN_BACKEND
:
ops
.
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
...
...
vllm/envs.py
View file @
4ddb5447
...
...
@@ -150,6 +150,7 @@ if TYPE_CHECKING:
VLLM_TBO_DECODE_BS
:
int
=
0
VLLM_ZERO_OVERHEAD
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
VLLM_USE_FLASH_ATTN_BACKEND
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MOE_FUSED_GATE"
,
"1"
))),
# vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_USE_FLASH_ATTN_BACKEND"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_BACKEND"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment