Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b918400d
Commit
b918400d
authored
May 21, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/vllm-0.8.5-zhangshao' into v0.8.5.post1-dev
parents
8fb5dea5
e02d110d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
245 additions
and
55 deletions
+245
-55
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+0
-1
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+118
-23
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+117
-24
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-4
vllm/utils.py
vllm/utils.py
+9
-3
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
b918400d
...
@@ -487,7 +487,6 @@ __device__ void paged_attention_kernel_opt(
...
@@ -487,7 +487,6 @@ __device__ void paged_attention_kernel_opt(
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
// printf("======xiabo_kvint8\n");
V_quant_vec
v_quant_vec
=
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
...
...
csrc/attention/attention_kernels_opt_tc.cu
View file @
b918400d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -88,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
...
@@ -88,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return
VLLM_SHFL_SYNC
(
sum
,
0
);
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
}
using
uint8x4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
uint8_t
))
))
uint8_t
;
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
...
@@ -95,12 +97,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
...
@@ -95,12 +97,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct
half4x2
{
struct
half4x2
{
half4_t
data
[
2
];
half4_t
data
[
2
];
};
};
struct
uint8x4x4
{
uint8x4_t
data
[
4
];
};
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
struct
vec2data
{
struct
vec2data
{
scalar_t
data
[
2
];
scalar_t
data
[
2
];
};
};
inline
__device__
float
uint82float
(
const
uint8_t
&
input
)
{
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
const
uint32_t
sign
=
w
&
UINT32_C
(
0x80000000
);
const
uint32_t
nonsign
=
w
&
UINT32_C
(
0x7FFFFFFF
);
uint32_t
renorm_shift
=
__clz
(
nonsign
);
renorm_shift
=
renorm_shift
>
4
?
renorm_shift
-
4
:
0
;
uint32_t
result
=
sign
|
((
nonsign
<<
renorm_shift
>>
4
)
+
((
0x78
-
renorm_shift
)
<<
23
));
return
c10
::
detail
::
fp32_from_bits
(
result
);
}
template
<
bool
is_half
,
bool
is_fp8
>
inline
__device__
half4_t
int8x4_to_half4
(
uint8x4_t
x
,
const
float
scale
)
{
half4_t
ret
;
if
constexpr
(
is_fp8
){
if
constexpr
(
is_half
){
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
ret
[
i
]
=
uint82float
(
x
[
i
])
*
scale
;
}
}
else
{
__nv_bfloat16
*
bd
=
reinterpret_cast
<
__nv_bfloat16
*>
(
&
ret
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
bd
[
i
]
=
__float2bfloat16
(
uint82float
(
x
[
i
])
*
scale
);
}
}
}
else
{
if
constexpr
(
is_half
){
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
ret
[
i
]
=
(
x
[
i
]
-
128.0
f
)
*
scale
;
}
}
else
{
__nv_bfloat16
*
bd
=
reinterpret_cast
<
__nv_bfloat16
*>
(
&
ret
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
bd
[
i
]
=
__float2bfloat16
((
x
[
i
]
-
128.0
f
)
*
scale
);
}
}
}
return
ret
;
}
template
<
bool
is_half
>
template
<
bool
is_half
>
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
{
{
...
@@ -165,7 +217,7 @@ __global__ void paged_attention_kernel_TC(
...
@@ -165,7 +217,7 @@ __global__ void paged_attention_kernel_TC(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
_ptr
,
const
float
*
v_scale
_ptr
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
int
PARTITION_SIZE
=
0
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
int
PARTITION_SIZE
=
0
)
{
#if defined(__gfx936__) || defined(__gfx928__)
#if defined(__gfx936__) || defined(__gfx928__)
...
@@ -177,6 +229,7 @@ __global__ void paged_attention_kernel_TC(
...
@@ -177,6 +229,7 @@ __global__ void paged_attention_kernel_TC(
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
...
@@ -193,7 +246,9 @@ __global__ void paged_attention_kernel_TC(
...
@@ -193,7 +246,9 @@ __global__ void paged_attention_kernel_TC(
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
rowid
=
lane
%
16
;
const
int
rowid
=
lane
%
16
;
const
int
rows
=
lane
/
16
;
const
int
rows
=
lane
/
16
;
const
float
k_scale
=*
k_scale_ptr
;
const
float
v_scale
=*
v_scale_ptr
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
head_idx
=
(
blockIdx
.
y
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
y
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
head_idx
=
(
blockIdx
.
y
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
y
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
...
@@ -242,25 +297,44 @@ __global__ void paged_attention_kernel_TC(
...
@@ -242,25 +297,44 @@ __global__ void paged_attention_kernel_TC(
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
8
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
x
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
half4x2
k_vec
[
2
];
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
#pragma unroll
half4x2
k_vec
[
2
];
for
(
int
i
=
0
;
i
<
3
;
i
++
){
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
#pragma unroll
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
512
);
for
(
int
i
=
0
;
i
<
3
;
i
++
){
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
WARP_SIZE
*
x
);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
//tail
{
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
}
}
//tail
else
{
{
uint8x4x4
k_quant
=*
reinterpret_cast
<
const
uint8x4x4
*>
(
k_ptr
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
0
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
1
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
1
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
2
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
3
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
k_quant
=*
reinterpret_cast
<
const
uint8x4x4
*>
(
k_ptr
+
WARP_SIZE
*
x
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
8
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
0
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
1
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
9
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
2
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
3
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
...
@@ -362,7 +436,14 @@ __global__ void paged_attention_kernel_TC(
...
@@ -362,7 +436,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -458,7 +539,14 @@ __global__ void paged_attention_kernel_TC(
...
@@ -458,7 +539,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -563,7 +651,14 @@ __global__ void paged_attention_kernel_TC(
...
@@ -563,7 +651,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -895,6 +990,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
...
@@ -895,6 +990,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
(
max_seq_len
>=
2000
&&
max_seq_len
<
3900
)))
reusekv
=
4
;
if
(
blocks
>=
150
||
batchsize
>=
16
||
qheads
>=
8
&&
(
batchsize
>=
4
||
(
max_seq_len
>=
2000
&&
max_seq_len
<
3900
)))
reusekv
=
4
;
}
}
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v2_launcher_opt_tc
(
void
paged_attention_v2_launcher_opt_tc
(
...
@@ -920,17 +1016,16 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -920,17 +1016,16 @@ void paged_attention_v2_launcher_opt_tc(
alibi_slopes
alibi_slopes
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
?
reinterpret_cast
<
const
float
*>
(
alibi_slopes
.
value
().
data_ptr
())
:
nullptr
;
:
nullptr
;
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
static
float
*
exp_sums_ptr
=
nullptr
;
static
float
*
exp_sums_ptr
=
nullptr
;
static
float
*
max_logits_ptr
=
nullptr
;
static
float
*
max_logits_ptr
=
nullptr
;
static
T
*
tmp_out_ptr
=
nullptr
;
static
T
*
tmp_out_ptr
=
nullptr
;
...
@@ -943,7 +1038,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -943,7 +1038,7 @@ void paged_attention_v2_launcher_opt_tc(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
...
...
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
b918400d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -71,6 +72,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
...
@@ -71,6 +72,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return
VLLM_SHFL_SYNC
(
sum
,
0
);
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
}
using
uint8x4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
uint8_t
))
))
uint8_t
;
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
half4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
_Float16
))
))
_Float16
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
v4bh
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
short
))
))
short
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
using
float4_t
=
__attribute__
(
(
__vector_size__
(
4
*
sizeof
(
float
))
))
float
;
...
@@ -78,12 +80,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
...
@@ -78,12 +80,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct
half4x2
{
struct
half4x2
{
half4_t
data
[
2
];
half4_t
data
[
2
];
};
};
struct
uint8x4x4
{
uint8x4_t
data
[
4
];
};
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
struct
vec2data
{
struct
vec2data
{
scalar_t
data
[
2
];
scalar_t
data
[
2
];
};
};
inline
__device__
float
uint82float
(
const
uint8_t
&
input
)
{
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
const
uint32_t
sign
=
w
&
UINT32_C
(
0x80000000
);
const
uint32_t
nonsign
=
w
&
UINT32_C
(
0x7FFFFFFF
);
uint32_t
renorm_shift
=
__clz
(
nonsign
);
renorm_shift
=
renorm_shift
>
4
?
renorm_shift
-
4
:
0
;
uint32_t
result
=
sign
|
((
nonsign
<<
renorm_shift
>>
4
)
+
((
0x78
-
renorm_shift
)
<<
23
));
return
c10
::
detail
::
fp32_from_bits
(
result
);
}
template
<
bool
is_half
,
bool
is_fp8
>
inline
__device__
half4_t
int8x4_to_half4
(
uint8x4_t
x
,
const
float
scale
)
{
half4_t
ret
;
if
constexpr
(
is_fp8
){
if
constexpr
(
is_half
){
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
ret
[
i
]
=
uint82float
(
x
[
i
])
*
scale
;
}
}
else
{
__nv_bfloat16
*
bd
=
reinterpret_cast
<
__nv_bfloat16
*>
(
&
ret
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
bd
[
i
]
=
__float2bfloat16
(
uint82float
(
x
[
i
])
*
scale
);
}
}
}
else
{
if
constexpr
(
is_half
){
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
ret
[
i
]
=
(
x
[
i
]
-
128.0
f
)
*
scale
;
}
}
else
{
__nv_bfloat16
*
bd
=
reinterpret_cast
<
__nv_bfloat16
*>
(
&
ret
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
){
bd
[
i
]
=
__float2bfloat16
((
x
[
i
]
-
128.0
f
)
*
scale
);
}
}
}
return
ret
;
}
template
<
bool
is_half
>
template
<
bool
is_half
>
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
inline
__device__
void
float4_2_half4
(
half4_t
&
dst
,
const
float4_t
&
src
)
{
{
...
@@ -148,7 +200,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -148,7 +200,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
_ptr
,
const
float
*
v_scale
_ptr
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
,
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
,
int
PARTITION_SIZE
=
0
)
{
const
int
*
__restrict__
attn_masks
=
nullptr
,
const
int
attn_masks_stride
=
0
,
int
PARTITION_SIZE
=
0
)
{
...
@@ -161,6 +213,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -161,6 +213,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
...
@@ -177,7 +230,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -177,7 +230,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
rowid
=
lane
%
16
;
const
int
rowid
=
lane
%
16
;
const
int
rows
=
lane
/
16
;
const
int
rows
=
lane
/
16
;
const
float
k_scale
=*
k_scale_ptr
;
const
float
v_scale
=*
v_scale_ptr
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
head_idx
=
(
blockIdx
.
y
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
y
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
head_idx
=
(
blockIdx
.
y
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
y
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
...
@@ -226,25 +281,44 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -226,25 +281,44 @@ __global__ void paged_attention_kernel_TC_with_mask(
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
8
;
const
cache_t
*
k_ptr_base
=
k_cache
+
kv_head_idx
*
kv_head_stride
+
lane
*
x
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
const
cache_t
*
k_ptr
=
k_ptr_base
+
physical_block_number
*
kv_block_stride
;
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
float4_t
qk_vec
=
{
0
,
0
,
0
,
0
};
half4x2
k_vec
[
2
];
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
#pragma unroll
half4x2
k_vec
[
2
];
for
(
int
i
=
0
;
i
<
3
;
i
++
){
k_vec
[
0
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
#pragma unroll
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
512
);
for
(
int
i
=
0
;
i
<
3
;
i
++
){
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
i
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
k_vec
[
1
-
i
%
2
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
(
i
+
1
)
*
WARP_SIZE
*
x
);
}
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
//tail
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
%
2
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
{
}
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
//tail
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
{
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
3
*
4
+
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
1
].
data
[
0
],
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
k_vec
[
1
].
data
[
1
],
q_vec
.
data
[
1
],
qk_vec
);
}
}
else
{
uint8x4x4
k_quant
=*
reinterpret_cast
<
const
uint8x4x4
*>
(
k_ptr
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
0
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
1
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
1
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
2
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
3
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
k_quant
=*
reinterpret_cast
<
const
uint8x4x4
*>
(
k_ptr
+
WARP_SIZE
*
x
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
8
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
0
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
1
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
if
(
rowid
<
q_boundary
)
q_vec
=
q_vecs
[
rowid
][
2
*
rows
+
9
];
builtin_amdgcn_mmac
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
2
],
k_scale
),
q_vec
.
data
[
0
],
qk_vec
);
v_mmac_f32_16x16x16_f16
<
is_half
>
(
int8x4_to_half4
<
is_half
,
is_fp8
>
(
k_quant
.
data
[
3
],
k_scale
),
q_vec
.
data
[
1
],
qk_vec
);
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
for
(
int
i
=
0
;
i
<
reuse_group
;
i
++
){
...
@@ -353,7 +427,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -353,7 +427,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -449,7 +530,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -449,7 +530,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -554,7 +642,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -554,7 +642,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
int
offset
=
i
*
1024
+
k
*
256
;
int
offset
=
i
*
1024
+
k
*
256
;
half4_t
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
half4_t
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
){
v_vec
=*
reinterpret_cast
<
const
half4_t
*>
(
v_ptr
+
offset
);
}
else
{
uint8x4_t
quant_v
=
*
reinterpret_cast
<
const
uint8x4_t
*>
(
v_ptr
+
offset
);
v_vec
=
int8x4_to_half4
<
is_half
,
is_fp8
>
(
quant_v
,
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
...
@@ -831,14 +926,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
...
@@ -831,14 +926,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
T
*
out_ptr
=
reinterpret_cast
<
T
*>
(
out
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
T
*
query_ptr
=
reinterpret_cast
<
T
*>
(
query
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
key_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
key_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
CACHE_T
*
value_cache_ptr
=
reinterpret_cast
<
CACHE_T
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
static
float
*
exp_sums_ptr
=
nullptr
;
static
float
*
exp_sums_ptr
=
nullptr
;
static
float
*
max_logits_ptr
=
nullptr
;
static
float
*
max_logits_ptr
=
nullptr
;
static
T
*
tmp_out_ptr
=
nullptr
;
static
T
*
tmp_out_ptr
=
nullptr
;
...
@@ -851,7 +944,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
...
@@ -851,7 +944,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
&&
KV_DTYPE
==
vllm
::
Fp8KVCacheDataType
::
kAuto
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
...
...
vllm/attention/ops/paged_attn.py
View file @
b918400d
...
@@ -130,10 +130,7 @@ class PagedAttention:
...
@@ -130,10 +130,7 @@ class PagedAttention:
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
kvquant
=
False
if
use_tc
and
head_size
==
128
:
if
(
kv_cache_dtype
==
"int8"
):
kvquant
=
True
if
use_tc
and
head_size
==
128
and
not
kvquant
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
...
...
vllm/utils.py
View file @
b918400d
...
@@ -77,6 +77,12 @@ logger = init_logger(__name__)
...
@@ -77,6 +77,12 @@ logger = init_logger(__name__)
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
gpuname
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
name
SUPPORT_TC
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
SUPPORT_TC
=
gpuname
.
startswith
(
'K100_AI'
)
or
gpuname
.
startswith
(
'BW'
)
def
_generate_random_int8
(
tensor
:
torch
.
Tensor
,
)
->
None
:
tensor
=
torch
.
randint
(
0
,
255
,
tensor
.
size
(),
dtype
=
torch
.
uint8
,
device
=
'cuda'
)
# Exception strings for non-implemented encoder/decoder scenarios
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
...
@@ -798,7 +804,7 @@ def create_kv_caches_with_random_flash(
...
@@ -798,7 +804,7 @@ def create_kv_caches_with_random_flash(
elif
cache_dtype
==
'fp8'
:
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
value_cache
)
_generate_random_int8
(
key_
value_cache
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
f
"Does not support key cache of type
{
cache_dtype
}
"
)
...
@@ -841,7 +847,7 @@ def create_kv_caches_with_random(
...
@@ -841,7 +847,7 @@ def create_kv_caches_with_random(
elif
cache_dtype
==
'fp8'
:
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_cache
,
-
scale
,
scale
)
_generate_random_fp8
(
key_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
key_
value_
cache
)
_generate_random_int8
(
key_cache
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
f
"Does not support key cache of type
{
cache_dtype
}
"
)
...
@@ -858,7 +864,7 @@ def create_kv_caches_with_random(
...
@@ -858,7 +864,7 @@ def create_kv_caches_with_random(
elif
cache_dtype
==
'fp8'
:
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
value_cache
,
-
scale
,
scale
)
_generate_random_fp8
(
value_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
key
_cache
)
_generate_random_int8
(
value
_cache
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Does not support value cache of type
{
cache_dtype
}
"
)
f
"Does not support value cache of type
{
cache_dtype
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment