gaoqiong / composable_kernel · Commit 635ea6a0

Commit 635ea6a0, authored Nov 14, 2023 by Qianfeng Zhang

    Merge branch 'mha-train-develop' into mha-train-develop-d0param

Parents: 81639679, 2f93e26f

Showing 3 changed files with 121 additions and 66 deletions:
  include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp                        +43 -3
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp  +38 -30
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp   +40 -33
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
@@ -88,9 +88,8 @@ struct BlockwiseSoftmax
     __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
     {
         // find max value
-        static_for<0, MRepeat, 1>{}([&](auto I) {
-            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
-        });
+        static_for<0, MRepeat, 1>{}(
+            [&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
         ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
         static_for<0, MRepeat, 1>{}([&](auto I) {
             BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
@@ -118,6 +117,47 @@ struct BlockwiseSoftmax
         });
     }

+    template <typename CThreadBuffer, typename WorkspaceBuffer>
+    __host__ __device__ void CalculateRowMax(CThreadBuffer& in_thread_buf,
+                                             WorkspaceBuffer& reduce_work_buf)
+    {
+        // find max value
+        static_for<0, MRepeat, 1>{}(
+            [&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
+        ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
+        static_for<0, MRepeat, 1>{}([&](auto I) {
+            BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
+            block_sync_lds();
+        });
+    };
+
+    template <typename CThreadBuffer, typename WorkspaceBuffer>
+    __host__ __device__ void CalculateRowExpSum(CThreadBuffer& in_thread_buf,
+                                                WorkspaceBuffer& reduce_work_buf,
+                                                BufferType& max_value_buf_new)
+    {
+        // calculate exp for elements, P=exp(s-max)
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                        ? 0
+                        : math::exp(in_thread_buf[offset] - max_value_buf_new(iM));
+            });
+        });
+
+        // sum data
+        static_for<0, MRepeat, 1>{}([&](auto I) {
+            sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
+        });
+        ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
+        static_for<0, MRepeat, 1>{}([&](auto I) {
+            BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
+            block_sync_lds();
+        });
+    }
+
     template <typename CThreadBuffer, typename LSEBuffer>
     __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
                                                  const LSEBuffer& lse_thread_buf)
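Run() formerly seeded the row max with reduce::Max's identity and fused the max and exp-sum phases; this commit switches the seed to ck::NumericLimits<AccDataType>::Lowest() and splits the two phases so the caller can feed a running max into the exp-sum. A scalar model of the split (plain C++, illustrative only: std::vector rows stand in for the thread/LDS buffers, and the free-function names merely echo the new members):

// Scalar model of CalculateRowMax: row_max[i] = max_j s[i][j].
// (The device version seeds with ck::NumericLimits<AccDataType>::Lowest(),
// reduces thread-locally, then reduces across the block through LDS.)
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> calculate_row_max(const std::vector<std::vector<float>>& s)
{
    std::vector<float> row_max(s.size(), std::numeric_limits<float>::lowest());
    for(std::size_t i = 0; i < s.size(); ++i)
        for(float v : s[i])
            row_max[i] = std::max(row_max[i], v);
    return row_max;
}

// Scalar model of CalculateRowExpSum: s[i][j] <- exp(s[i][j] - m[i]), with
// row_sum[i] accumulating the result. The max m is supplied by the caller,
// so it can be the *running* max rather than the current tile's own max.
std::vector<float> calculate_row_exp_sum(std::vector<std::vector<float>>& s,
                                         const std::vector<float>& m)
{
    std::vector<float> row_sum(s.size(), 0.0f);
    for(std::size_t i = 0; i < s.size(); ++i)
        for(float& v : s[i])
        {
            v = std::exp(v - m[i]);
            row_sum[i] += v;
        }
    return row_sum;
}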
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
@@ -1147,6 +1147,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 0),
             tensor_operation::element_wise::PassThrough{}};

+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
+        constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
+        constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
+        constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
+        constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
+        constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
+        constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
+        constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+        constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
+        constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
+        constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
+
         do
         {
             auto n_block_data_idx_on_grid =
@@ -1262,11 +1277,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                     d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                     make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
             }

-            // softmax
+            // calculate current max
+            blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
+            // current max
             SoftmaxBuf& max = blockwise_softmax.max_value_buf;
+            // accumulated max
+            running_max_new = mathext::max(max, running_max);
+            // calculate current exp_sum
+            blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
+            // current exp_sum
             SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
-            blockwise_softmax.Run(acc_thread_buf, workspace_buf);
+            // accumulated exp_sum
+            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;

             constexpr auto iterator_offset = Number<16 * DropoutStep>{};
             constexpr auto iterator_step =
                 Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};
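For orientation (a restatement of the lines above, not part of the commit): writing $m$ for the running row max, $\tilde m$ for the current tile's max, and $l$ for the running exp-sum, this is the standard online-softmax recurrence

$$m^{\mathrm{new}} = \max(m, \tilde m), \qquad l^{\mathrm{new}} = e^{m - m^{\mathrm{new}}}\, l + \sum_j e^{s_j - m^{\mathrm{new}}}.$$

Because CalculateRowExpSum exponentiates against running_max_new rather than the tile's own max, the tile sum arrives with the factor $e^{\tilde m - m^{\mathrm{new}}}$ already folded in; this is why the correction term mathext::exp(max - running_max_new) * sum from the old code (removed in the next hunk) is no longer needed.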
@@ -1325,11 +1352,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 });
             }

-            // TODO: may convert to log domain
-            running_max_new = mathext::max(max, running_max);
-            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
-                              mathext::exp(max - running_max_new) * sum;
-
             // gemm1
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
@@ -1394,31 +1416,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 }
             }
             // end gemm1

-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-            constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-            constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-            constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-            constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-            constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-            constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-            constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-            constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
-            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-                make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
-            constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
-            constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-
             static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                 static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                     auto I =
                         Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
                     FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
                     FloatGemmAcc c    = c_thread_buf[I];    // O
                     FloatGemmAcc c_new =
-                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM]; // Formula by Dao et al.,
-                                             // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
+                        math::exp(running_max[iM] - running_max_new[iM]) * c + acc1;
+                    // Formula by Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf section 3.1

                     c_thread_buf(I) = c_new; // O_new
                 });
             });
@@ -1436,6 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             block_sync_lds(); // wait for gemm1 LDS read
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

+        static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
+            static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
+            });
+        });
+
         // Calculate max + ln(sum) and write out
         if constexpr(IsLseStoring)
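The net effect of the changes in this file: the per-iteration output update no longer divides by running_sum_new; the accumulator stays unnormalized inside the j loop, and the post-loop pass added above divides once by the final exp-sum. In the notation of Dao et al. (a sketch of the algebra, not commit text), the old per-step update

$$O \leftarrow \bigl(l\, e^{m - m^{\mathrm{new}}}\, O + e^{\tilde m - m^{\mathrm{new}}}\, \tilde P V\bigr) / l^{\mathrm{new}}$$

becomes

$$\hat O \leftarrow e^{m - m^{\mathrm{new}}}\, \hat O + \tilde P V, \qquad O = \hat O / l^{\mathrm{final}},$$

where $\tilde P$ is already shifted by $m^{\mathrm{new}}$ via CalculateRowExpSum. The two forms are algebraically equivalent (the old code maintained the invariant $O = \hat O / l$ after every step), but the new one trades a per-element exp and divide inside the loop for a single division at the end.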
@@ -1468,10 +1480,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
             constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);

-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
             // TODO: hacky, fix it!
             // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
@@ -918,6 +918,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
             make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0));

+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
+        constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
+        constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
+        constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
+        constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
+        constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
+        constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
+        constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+        constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
+        constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
+        constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
+
         index_t gemm1_k_block_outer_index = 0;
         do
         {
@@ -1062,16 +1077,22 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                 }
             }

-            // softmax
+            // calculate current max
+            blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
+            // current max
             SoftmaxBuf& max = blockwise_softmax.max_value_buf;
-            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
-            blockwise_softmax.Run(acc_thread_buf, workspace_buf);
-            // TODO: may convert to log domain
-            running_max_new = mathext::max(max, running_max);
-            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
-                              mathext::exp(max - running_max_new) * sum;
+            // accumulated max
+            running_max_new = mathext::max(max, running_max);
+            // calculate current exp_sum
+            blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
+            // current exp_sum
+            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
+            // accumulated exp_sum
+            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;

             // gemm1
             {
@@ -1137,30 +1158,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                 }
             }
             // end gemm1

-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-            constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-            constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-            constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-            constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-            constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-            constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-            constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-            constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
-            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-                make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
-            constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
-            constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-
             static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                 static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                     auto I =
                         Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
                     FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
                     FloatGemmAcc c    = c_thread_buf[I];    // O
                     FloatGemmAcc c_new =
-                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM];
+                        math::exp(running_max[iM] - running_max_new[iM]) * c + acc1;
                     // Formula by Dao et al.,
                     // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1

                     c_thread_buf(I) = c_new; // O_new
@@ -1179,6 +1183,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
             block_sync_lds(); // wait for gemm1 LDS read
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

+        static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
+            static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
+            });
+        });
+
         // shuffle C and write out
         {
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
@@ -1188,10 +1199,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
             constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
             constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);

-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
             // TODO: hacky, fix it!
             // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =