gaoqiong / composable_kernel · Commits

Commit d8154515: "code regular"
Authored Jul 27, 2022 by ltqin
Parent: a085b740

Showing 2 changed files with 17 additions and 28 deletions:

    example/01_gemm/gemm_xdl_fp16_flash_attention.cpp (+3, -3)
    include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp (+14, -25)
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+    < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on
 using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
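Reviewer note on the retune above: BlockSize drops 128 → 64, the block tile 32×32 → 16×16, MXdlPerWave 2 → 1, and both A/B copy thread clusters go from S<4, 32, 1> to S<4, 16, 1>. A minimal standalone sketch (constants copied from the new template arguments; the names are illustrative, not CK's API) checks that the copy cluster still spans the whole workgroup:

// Sketch only: mirrors the retuned template arguments to sanity-check the
// thread-cluster math; constant names are illustrative, not CK's API.
#include <cstddef>

constexpr std::size_t BlockSize = 64;  // was 128
constexpr std::size_t MPerBlock = 16;  // was 32
constexpr std::size_t NPerBlock = 16;  // was 32
constexpr std::size_t ClusterK0 = 4, ClusterM = 16, ClusterK1 = 1; // S<4, 16, 1>, was S<4, 32, 1>

// every thread of the workgroup takes part in the blockwise A/B copy, so the
// cluster lengths must multiply out to the block size
static_assert(ClusterK0 * ClusterM * ClusterK1 == BlockSize,
              "copy thread cluster must cover the whole workgroup");

int main() { return 0; }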
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
     bool time_kernel = false;

     // GEMM shape
-    ck::index_t M = 32;
-    ck::index_t N = 32;
+    ck::index_t M = 16;
+    ck::index_t N = 16;
     ck::index_t K = 64;
     ck::index_t StrideA = K;
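The shape change tracks the retuned instance: with the new MPerBlock = NPerBlock = 16 tile, an M = N = 16 problem is covered by a single workgroup tile. A trivial sketch of that grid arithmetic (hypothetical helper, not CK's launch logic):

// Sketch: hypothetical grid-size arithmetic, not CK's launch logic.
#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int M = 16, N = 16;                 // new problem size above
    const int MPerBlock = 16, NPerBlock = 16; // new tile of DeviceGemmInstance0
    assert(ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock) == 1); // one workgroup covers C
    return 0;
}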
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
@@ -36,7 +36,6 @@ struct BlockwiseSoftmax_V1
     template <typename TopIdx>
     __host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top)
     {
         const auto index = idx_top[I0];
         const auto m     = (index / WaveSize) * MPerXDL + index % MPerXDL;
         const auto k     = (index % WaveSize) / MPerXDL;
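For orientation, CalculateBottomIndex above folds a flat thread index into (m, k) coordinates of the XDL output layout. Assuming WaveSize = 64 and MPerXDL = 16 (typical values for these instances, not fixed by this diff): lane 0 → (0, 0), lane 15 → (15, 0), lane 16 → (0, 1), lane 64 → (16, 0). A standalone sketch of the same arithmetic:

// Sketch: reproduces the lane -> (m, k) mapping from CalculateBottomIndex,
// assuming WaveSize = 64 and MPerXDL = 16 (typical values, not fixed here).
constexpr int WaveSize = 64;
constexpr int MPerXDL  = 16;

constexpr int row_of(int lane) { return (lane / WaveSize) * MPerXDL + lane % MPerXDL; }
constexpr int col_of(int lane) { return (lane % WaveSize) / MPerXDL; }

int main()
{
    static_assert(row_of(0) == 0 && col_of(0) == 0, "lane 0 -> (0, 0)");
    static_assert(row_of(15) == 15 && col_of(15) == 0, "last row of first k column");
    static_assert(row_of(16) == 0 && col_of(16) == 1, "wraps to the next k column");
    static_assert(row_of(64) == 16 && col_of(64) == 0, "second wave, next m block");
    return 0;
}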
@@ -101,12 +100,12 @@ struct BlockwiseSoftmax_V1
         auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<AccDataType*>(p_softmax), MPerBlock * 2);
         // static auto lds_buffer_m_k = GetSpaceForPreMax();
         // thread id map to thread layout
         const index_t thread_local_id = get_thread_local_1d_id();
         const auto thread_cluster_idx =
             BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id));
         const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
-        // const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
+        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
         //
         // find max value
         //
@@ -123,17 +122,15 @@ struct BlockwiseSoftmax_V1
             ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
         });
-        //{const index_t thread_local_id = get_thread_local_1d_id();
-        // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
-        // ignore = p_reduce_work_buffer;}
         // block reduce for max
         BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
         block_sync_lds();

         // save max value
-        softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
-            make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
-        printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);
+        if(0 == thread_k_cluster_id)
+        {
+            softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
+                make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
+        }

         //
         // softmax
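Two behavioral fixes land in this hunk: the debug printf goes away, and the LDS store of the row max is guarded so that only the thread with thread_k_cluster_id == 0 writes each row's slot, rather than every thread in the row writing the same address. A host-side sketch of that single-writer pattern (illustrative names, not CK's buffer types):

// Sketch: the single-writer pattern the new guard enforces, simulated on the
// host; rows/cols/lds are illustrative names, not CK's types.
#include <vector>

int main()
{
    const int rows = 16, cols = 4;                 // an m x k thread cluster
    std::vector<float> row_max(rows * cols, 2.5f); // after the block reduce, every
                                                   // thread in a row holds the same max
    std::vector<float> lds(rows, 0.0f);            // one LDS slot per row

    for(int m = 0; m < rows; ++m)
        for(int k = 0; k < cols; ++k)
            if(k == 0)                              // only the k == 0 thread commits,
                lds[m] = row_max[m * cols + k];     // avoiding redundant same-address writes

    return lds[0] > 0.0f ? 0 : 1;
}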
@@ -143,7 +140,7 @@ struct BlockwiseSoftmax_V1
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
         });
-        // calculate exp for elements
+        // calculate exp for elements, P=exp(s-max)
         static_for<0, NRepeat, 1>{}([&](auto n) {
             constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
             auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
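The updated comment names the quantity being computed: P = exp(s - max), the numerator of a numerically stable softmax. Subtracting the row max before exponentiating keeps exp from overflowing; the row sum of P then serves as the denominator. A minimal scalar sketch of the same math:

// Sketch: numerically stable softmax, the scalar analogue of the blockwise
// P = exp(s - max) / sum(exp(s - max)) computed here.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax(const std::vector<float>& s)
{
    const float m = *std::max_element(s.begin(), s.end()); // row max
    float sum = 0.0f;
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        sum += (p[i] = std::exp(s[i] - m)); // P = exp(s - max)
    for(float& x : p)
        x /= sum;                           // normalize by the row sum
    return p;
}

int main() { return softmax({1.0f, 2.0f, 3.0f}).empty(); }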
@@ -164,19 +161,11 @@ struct BlockwiseSoftmax_V1
         block_sync_lds();

         // save sum
-        softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
-            make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
-        // change elements
-        /* static_for<0, NRepeat, 1>{}([&](auto n) {
-            constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
-            auto& xdlops_out =
-                in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
-            static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
-                xdlops_out.template AsType<float>()(iK) =
-                    xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
-            });
-        });*/
+        if(0 == thread_k_cluster_id)
+        {
+            softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
+                make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
+        }
     }
 };
 } // namespace ck
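The deleted "change elements" block divided each P element by the row sum in place; the new code instead stores the row sum in LDS (again via the single k == 0 writer) and leaves P unnormalized, the deferred-rescale arrangement flash attention relies on. A host-side sketch of that deferred division, assuming the two per-row LDS slots hold the sum at offset 0 and the max at offset 1 as the CalculateOffset calls suggest:

// Sketch: deferred softmax normalization. Assumes the per-row LDS slots hold
// (sum, max) as the CalculateOffset(make_tuple(m, 0 / 1)) calls suggest;
// host-side illustration, not CK's pipeline.
#include <vector>

int main()
{
    const int M = 16;
    std::vector<float> lds(2 * M, 0.0f); // slot 0: row sum, slot 1: row max
    std::vector<float> p(M, 1.0f);       // unnormalized P = exp(s - max)

    for(int m = 0; m < M; ++m)
    {
        lds[2 * m + 0] = 4.0f; // row sum, written by the k == 0 thread
        lds[2 * m + 1] = 0.5f; // row max, written earlier under the same guard
    }

    // a later pipeline stage divides by the stored sum, replacing the removed
    // in-place normalization
    for(int m = 0; m < M; ++m)
        p[m] /= lds[2 * m + 0];

    return p[0] == 0.25f ? 0 : 1;
}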