Commit cc6d659f authored Jan 12, 2025 by Po Yen, Chen

Re-format interface sources

Parent: 5a683756

Showing 3 changed files with 279 additions and 302 deletions.
example/ck_tile/18_paged_attention/include/paged_attention.hpp   +21  -17
example/ck_tile/18_paged_attention/itfs/paged_attention.cpp     +190 -214
example/ck_tile/18_paged_attention/py_itfs/paged_attention.cu    +68  -71
example/ck_tile/18_paged_attention/include/paged_attention.hpp
@@ -5,31 +5,37 @@
#include <hip/hip_runtime.h>

namespace native {

enum class ScalarType
{
    Half,
    BFloat16,
};

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type)
{
    switch(scalar_type)
    {
    case ScalarType::Half: stream << "Half"; break;
    case ScalarType::BFloat16: stream << "BFloat16"; break;
    }
    return stream;
}

enum class Fp8KVCacheDataType
{
    kAuto    = 0,
    kFp8E4M3 = 1,
    kFp8E5M2 = 2,
};

struct paged_attention_traits
{
    ScalarType q_type;
    std::string kv_cache_dtype;
};

struct paged_attention_args
{
    int head_size;
    int num_seqs;
@@ -63,9 +69,7 @@ struct paged_attention_args {
    int64_t partition_size;
};

void paged_attention(const paged_attention_traits& traits,
                     const paged_attention_args& args,
                     hipStream_t stream);

} // namespace native
\ No newline at end of file
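For orientation, here is a minimal host-side sketch (not part of the commit) of how this interface might be driven. Field names follow the structs visible in the hunks above; the numeric values are illustrative, the pointer members are assumed to accept raw device pointers, and members elided by the diff are assumed to be filled in the same way.

// Hypothetical usage sketch: populate the traits/args structs declared in
// paged_attention.hpp and hand them to native::paged_attention.
#include <hip/hip_runtime.h>
#include "paged_attention.hpp"

void run_paged_attention_half(void* query_dev, void* out_dev, hipStream_t stream)
{
    native::paged_attention_traits traits;
    traits.q_type         = native::ScalarType::Half; // query tensor is fp16
    traits.kv_cache_dtype = "auto";                   // "fp8"/"fp8_e4m3" selects the fp8 KV path

    native::paged_attention_args args{};
    args.head_size      = 128; // the dispatch in paged_attention.cpp handles 64 or 128
    args.num_seqs       = 4;
    args.partition_size = 256; // 256 or 512 are accepted
    args.query_ptr      = query_dev;
    args.out_ptr        = out_dev;
    // ... the remaining pointers, strides, and sizes (KV cache, block tables,
    //     context lengths, scales) are set from device buffers in the same way ...

    native::paged_attention(traits, args, stream);
}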
example/ck_tile/18_paged_attention/itfs/paged_attention.cpp
@@ -22,28 +22,55 @@
#include "kernel/paged_attention_kernel.hpp"

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO)                        \
    paged_attention_ll4mi_QKV_kernel<T,                           \
                                     KVT,                         \
                                     KV_DTYPE,                    \
                                     OUTT,                        \
                                     BLOCK_SIZE,                  \
                                     HEAD_SIZE,                   \
                                     NTHR,                        \
                                     GQA_RATIO>                   \
        <<<grid, block, 0, stream>>>(query_ptr,                   \
                                     key_cache_ptr,               \
                                     value_cache_ptr,             \
                                     args.num_kv_heads,           \
                                     args.scale,                  \
                                     args.block_tables_ptr,       \
                                     args.context_lens_ptr,       \
                                     args.max_num_blocks_per_seq, \
                                     args.alibi_slopes_ptr,       \
                                     args.q_stride,               \
                                     args.kv_block_stride,        \
                                     args.kv_head_stride,         \
                                     args.exp_sums_ptr,           \
                                     args.max_logits_ptr,         \
                                     tmp_out_ptr,                 \
                                     out_ptr,                     \
                                     max_ctx_blocks,              \
                                     args.k_scale,                \
                                     args.v_scale,                \
                                     args.fp8_out_scale_ptr);

#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                                                         \
    paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE, NPAR_LOOPS> \
        <<<reduce_grid, reduce_block, 0, stream>>>(out_ptr,                                         \
                                                   args.exp_sums_ptr,                              \
                                                   args.max_logits_ptr,                            \
                                                   tmp_out_ptr,                                    \
                                                   args.context_lens_ptr,                          \
                                                   max_num_partitions,                             \
                                                   args.fp8_out_scale_ptr);

namespace {

template <typename T,
          typename KVT,
          vllm::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE,
          int HEAD_SIZE,
          typename OUTT,
          int PARTITION_SIZE>
void paged_attention_custom_launcher(const native::paged_attention_args& args, hipStream_t stream)
{
    T* tmp_out_ptr = reinterpret_cast<T*>(args.tmp_out_ptr);
    T* query_ptr   = reinterpret_cast<T*>(args.query_ptr);
@@ -52,8 +79,7 @@ void paged_attention_custom_launcher(
    OUTT* out_ptr = reinterpret_cast<OUTT*>(args.out_ptr);

    const int max_ctx_blocks     = DIVIDE_ROUND_UP(args.max_context_len, BLOCK_SIZE);
    const int max_num_partitions = DIVIDE_ROUND_UP(args.max_context_len, PARTITION_SIZE);
    const int gqa_ratio          = args.num_heads / args.num_kv_heads;
    assert(args.num_heads % args.num_kv_heads == 0);
    assert(args.head_size == HEAD_SIZE);
@@ -62,58 +88,25 @@ void paged_attention_custom_launcher(
    dim3 grid(args.num_seqs, max_num_partitions, args.num_kv_heads);
    dim3 block(NTHR);
    switch(gqa_ratio)
    {
    case 1: LAUNCH_CUSTOM_ATTENTION(1); break;
    case 2: LAUNCH_CUSTOM_ATTENTION(2); break;
    case 3: LAUNCH_CUSTOM_ATTENTION(3); break;
    case 4: LAUNCH_CUSTOM_ATTENTION(4); break;
    case 5: LAUNCH_CUSTOM_ATTENTION(5); break;
    case 6: LAUNCH_CUSTOM_ATTENTION(6); break;
    case 7: LAUNCH_CUSTOM_ATTENTION(7); break;
    case 8: LAUNCH_CUSTOM_ATTENTION(8); break;
    case 9: LAUNCH_CUSTOM_ATTENTION(9); break;
    case 10: LAUNCH_CUSTOM_ATTENTION(10); break;
    case 11: LAUNCH_CUSTOM_ATTENTION(11); break;
    case 12: LAUNCH_CUSTOM_ATTENTION(12); break;
    case 13: LAUNCH_CUSTOM_ATTENTION(13); break;
    case 14: LAUNCH_CUSTOM_ATTENTION(14); break;
    case 15: LAUNCH_CUSTOM_ATTENTION(15); break;
    case 16: LAUNCH_CUSTOM_ATTENTION(16); break;
    default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break;
    }

    // reduction kernel is only required if max_context_len > partition size,
@@ -121,134 +114,117 @@ void paged_attention_custom_launcher(
    // note there are cases with graphing where max_context_len is the max
    // supported by graphing, not the actual max among all the sequences: in that
    // case reduction kernel will still run but return immediately
    if(args.max_context_len > PARTITION_SIZE)
    {
        dim3 reduce_grid(args.num_heads, args.num_seqs);
        dim3 reduce_block(args.head_size);
        const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
        // support upto 8*64*256=128K context length
        switch(npar_loops)
        {
        case 1: LAUNCH_CUSTOM_REDUCTION(1); break;
        case 2: LAUNCH_CUSTOM_REDUCTION(2); break;
        case 3: LAUNCH_CUSTOM_REDUCTION(3); break;
        case 4: LAUNCH_CUSTOM_REDUCTION(4); break;
        case 5: LAUNCH_CUSTOM_REDUCTION(5); break;
        case 6: LAUNCH_CUSTOM_REDUCTION(6); break;
        case 7: LAUNCH_CUSTOM_REDUCTION(7); break;
        case 8: LAUNCH_CUSTOM_REDUCTION(8); break;
        default: TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break;
        }
    }
}

} // namespace

#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE)                \
    paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE>(args,   \
                                                                                        stream);

#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT)                  \
    switch(args.partition_size)                                                                  \
    {                                                                                            \
    case 256: CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); break;     \
    case 512: CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); break;     \
    default: TORCH_CHECK(false, "Unsupported partition size: ", args.partition_size); break;     \
    }

#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)       \
    if(args.fp8_out_scale_ptr)                                                \
    {                                                                         \
        TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");           \
    }                                                                         \
    else                                                                      \
    {                                                                         \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
    }
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)             \
    if(args.fp8_out_scale_ptr)                                                      \
    {                                                                               \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \
    }                                                                               \
    else                                                                            \
    {                                                                               \
        CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T);       \
    }
#endif

#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                          \
    switch(args.block_size)                                                            \
    {                                                                                  \
    case 16: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); break;         \
    case 32: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); break;         \
    default: TORCH_CHECK(false, "Unsupported block size: ", args.block_size); break;   \
    }

#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)                              \
    switch(args.head_size)                                                           \
    {                                                                                \
    case 64: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); break;                  \
    case 128: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); break;                \
    default: TORCH_CHECK(false, "Unsupported head size: ", args.head_size); break;   \
    }

namespace native {

void paged_attention(const paged_attention_traits& traits,
                     const paged_attention_args& args,
                     hipStream_t stream)
{
    if(traits.kv_cache_dtype == "auto")
    {
        if(traits.q_type == ScalarType::Half)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto);
        }
        else if(traits.q_type == ScalarType::BFloat16)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                __hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
        }
        else
        {
            TORCH_CHECK(false, "Unsupported data type: ", traits.q_type);
        }
    }
    else if(traits.kv_cache_dtype == "fp8" || traits.kv_cache_dtype == "fp8_e4m3")
    {
        if(traits.q_type == ScalarType::Half)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
        }
        else if(traits.q_type == ScalarType::BFloat16)
        {
            CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                __hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
        }
        else
        {
            TORCH_CHECK(false, "Unsupported data type: ", traits.q_type);
        }
    }
    else
    {
        TORCH_CHECK(false, "Unsupported KV cache dtype: ", traits.kv_cache_dtype);
    }
}

} // namespace native
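The launcher's grid sizing reduces to a few ceiling divisions. Below is a standalone sketch of that arithmetic, not taken from the commit: divide_round_up stands in for the DIVIDE_ROUND_UP macro, WARP_SIZE is assumed to be 64 (gfx9 wavefront size), and the configuration values are illustrative.

// Standalone sketch of the partitioning arithmetic used by the launcher above.
#include <cstdio>

constexpr int divide_round_up(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int warp_size       = 64;    // wavefront size assumed for gfx9
    constexpr int block_size      = 16;    // KV cache block size (16 or 32 are dispatched)
    constexpr int partition_size  = 256;   // 256 or 512 are dispatched
    constexpr int max_context_len = 32768; // example maximum sequence length

    const int max_ctx_blocks     = divide_round_up(max_context_len, block_size);     // 2048
    const int max_num_partitions = divide_round_up(max_context_len, partition_size); // 128
    const int npar_loops         = divide_round_up(max_num_partitions, warp_size);   // 2

    // With partition_size = 256 and warp_size = 64, npar_loops tops out at 8,
    // which matches the "support upto 8*64*256=128K context length" comment.
    std::printf("ctx_blocks=%d partitions=%d npar_loops=%d\n",
                max_ctx_blocks, max_num_partitions, npar_loops);
    return 0;
}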
example/ck_tile/18_paged_attention/py_itfs/paged_attention.cu
@@ -25,27 +25,28 @@ void paged_attention(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,     // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,       // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens, // [num_seqs]
    int64_t block_size,
    int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    double k_scale,
    double v_scale,
    const c10::optional<torch::Tensor>& fp8_out_scale,
    int64_t partition_size)
{
    native::paged_attention_traits traits;
    traits.q_type         = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
                                                                   : native::ScalarType::BFloat16);
    traits.kv_cache_dtype = kv_cache_dtype;

    native::paged_attention_args args;
@@ -62,9 +63,7 @@ void paged_attention(
    // NOTE: alibi_slopes is optional.
    args.alibi_slopes_ptr =
        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

    args.exp_sums_ptr   = reinterpret_cast<float*>(exp_sums.data_ptr());
    args.max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
@@ -77,9 +76,7 @@ void paged_attention(
    // NOTE: fp8_out_scale is optional.
    args.fp8_out_scale_ptr =
        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) : nullptr;

    args.out_ptr = out.data_ptr();

    args.block_size = block_size;
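The PyTorch-facing wrapper above mostly translates tensors into the plain traits/args structs. The following reduced sketch, not part of the commit, isolates the two translation patterns it relies on (dtype selection and optional-tensor pointer resolution); the helper names are hypothetical and the torch tensors are assumed to already live on the GPU.

// Reduced sketch of the translation done by the wrapper.
#include <torch/extension.h>
#include "paged_attention.hpp"

namespace {

native::ScalarType to_native_scalar_type(const torch::Tensor& query)
{
    // Mirrors the ternary in paged_attention.cu: anything that is not Half is
    // treated as BFloat16; unsupported dtypes are rejected later by TORCH_CHECK.
    return query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
                                                 : native::ScalarType::BFloat16;
}

const float* optional_float_ptr(const c10::optional<torch::Tensor>& t)
{
    // Same pattern used above for alibi_slopes_ptr and fp8_out_scale_ptr.
    return t ? reinterpret_cast<const float*>(t.value().data_ptr()) : nullptr;
}

} // namespace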