gaoqiong / composable_kernel_ROCM / Commits / cc6d659f

Commit cc6d659f authored Jan 12, 2025 by Po Yen, Chen

Re-format interface sources

parent 5a683756
Showing 3 changed files with 279 additions and 302 deletions.
example/ck_tile/18_paged_attention/include/paged_attention.hpp  (+21, −17)
example/ck_tile/18_paged_attention/itfs/paged_attention.cpp     (+190, −214)
example/ck_tile/18_paged_attention/py_itfs/paged_attention.cu   (+68, −71)
example/ck_tile/18_paged_attention/include/paged_attention.hpp
@@ -5,37 +5,43 @@

#include <hip/hip_runtime.h>

namespace native {

enum class ScalarType
{
    Half,
    BFloat16,
};

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type)
{
    switch(scalar_type)
    {
    case ScalarType::Half: stream << "Half"; break;
    case ScalarType::BFloat16: stream << "BFloat16"; break;
    }
    return stream;
}

enum class Fp8KVCacheDataType
{
    kAuto    = 0,
    kFp8E4M3 = 1,
    kFp8E5M2 = 2,
};

struct paged_attention_traits
{
    ScalarType q_type;
    std::string kv_cache_dtype;
};

struct paged_attention_args
{
    int head_size;
    int num_seqs;
    int num_heads;
    int num_kv_heads;
    int max_num_blocks_per_seq;
    int q_stride;
    int kv_block_stride;
...
@@ -63,9 +69,7 @@ struct paged_attention_args {
    int64_t partition_size;
};

void paged_attention(const paged_attention_traits& traits,
                     const paged_attention_args& args,
                     hipStream_t stream);

} // namespace native
\ No newline at end of file
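
The two structs above are the whole host-facing surface: a caller fills paged_attention_traits and paged_attention_args, then passes a HIP stream. A minimal usage sketch follows; run_example and all field values are illustrative assumptions, and the remaining paged_attention_args members (device pointers, scales) sit in the elided part of this hunk.

// Hypothetical usage sketch, not part of this commit. Values are
// placeholders; only members visible in this diff are assigned.
#include <hip/hip_runtime.h>
#include "paged_attention.hpp"

void run_example()
{
    native::paged_attention_traits traits;
    traits.q_type         = native::ScalarType::Half;
    traits.kv_cache_dtype = "auto"; // assumed spelling, cf. Fp8KVCacheDataType::kAuto

    native::paged_attention_args args;
    args.head_size = 128; // illustrative sizes only
    args.num_seqs  = 4;
    args.num_heads = 32;
    // ... remaining members (pointers, strides, scales) omitted here ...

    hipStream_t stream = nullptr;
    hipStreamCreate(&stream); // error checking elided
    native::paged_attention(traits, args, stream);
    hipStreamDestroy(stream);
}

The paged_attention.cu file below follows exactly this pattern, filling the structs from Torch tensor metadata instead of constants.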
example/ck_tile/18_paged_attention/itfs/paged_attention.cpp

This diff is collapsed.
example/ck_tile/18_paged_attention/py_itfs/paged_attention.cu

@@ -25,76 +25,73 @@ void paged_attention(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,     // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,   // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,      // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens, // [num_seqs]
    int64_t block_size,
    int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    double k_scale,
    double v_scale,
    const c10::optional<torch::Tensor>& fp8_out_scale,
    int64_t partition_size)
{
    native::paged_attention_traits traits;
    traits.q_type = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
                                                           : native::ScalarType::BFloat16);
    traits.kv_cache_dtype = kv_cache_dtype;

    native::paged_attention_args args;
    args.head_size              = query.size(2);
    args.num_seqs               = query.size(0);
    args.num_heads              = query.size(1);
    args.max_num_blocks_per_seq = block_tables.size(1);
    args.q_stride               = query.stride(0);
    args.kv_block_stride        = key_cache.stride(0);
    args.kv_head_stride         = key_cache.stride(1);
    // NOTE: alibi_slopes is optional.
    args.alibi_slopes_ptr =
        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
                     : nullptr;
    args.exp_sums_ptr     = reinterpret_cast<float*>(exp_sums.data_ptr());
    args.max_logits_ptr   = reinterpret_cast<float*>(max_logits.data_ptr());
    args.tmp_out_ptr      = tmp_out.data_ptr();
    args.query_ptr        = query.data_ptr();
    args.key_cache_ptr    = key_cache.data_ptr();
    args.value_cache_ptr  = value_cache.data_ptr();
    args.block_tables_ptr = block_tables.data_ptr<int>();
    args.context_lens_ptr = context_lens.data_ptr<int>();
    // NOTE: fp8_out_scale is optional.
    args.fp8_out_scale_ptr =
        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
                      : nullptr;
    args.out_ptr         = out.data_ptr();
    args.block_size      = block_size;
    args.max_context_len = max_context_len;
    args.num_kv_heads    = num_kv_heads;
    args.partition_size  = partition_size;
    args.scale           = scale;
    args.k_scale         = k_scale;
    args.v_scale         = v_scale;

    hipStream_t stream = nullptr;
    HIP_CHECK_ERROR(hipStreamCreate(&stream));

    native::paged_attention(traits, args, stream);

    HIP_CHECK_ERROR(hipStreamDestroy(stream));
}
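
This hunk shows only the C++ wrapper; the py_itfs directory name suggests it is driven from Python, but the registration code lies outside the diff. One conventional way to expose such a wrapper, sketched under the assumption of a standard PyTorch C++ extension build (the docstring is an assumption, not taken from this repository):

// Hypothetical binding sketch, not part of this commit: expose the
// wrapper above as a PyTorch extension entry point.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("paged_attention",
          &paged_attention,
          "Paged attention forward (CK-tile, HIP backend)");
}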