Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
61f4a7ee
Commit
61f4a7ee
authored
Dec 28, 2022
by
Anthony Chang
Browse files
implement scaling
parent
aa0ee8e2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
27 deletions
+19
-27
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+1
-11
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+18
-16
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
61f4a7ee
...
...
@@ -212,11 +212,6 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
#endif
// P = Softmax(S)
// >>> scipy.special.softmax(numpy.eye(4), 1)
// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
},
&
lse_g_m
);
...
...
@@ -249,8 +244,7 @@ int run(int argc, char* argv[])
ck
::
index_t
G0
=
3
;
ck
::
index_t
G1
=
2
;
// float alpha = 1.f / std::sqrt(K); // TODO: make scaling aware
float
alpha
=
1.
f
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
false
;
bool
output_permute
=
false
;
...
...
@@ -488,10 +482,6 @@ int run(int argc, char* argv[])
return
0
;
}
if
(
alpha
!=
1.0
f
)
{
std
::
cout
<<
"not yet implemented scaling"
<<
std
::
endl
;
// TODO: make scaling aware
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
61f4a7ee
...
...
@@ -24,7 +24,7 @@ template <typename DataType,
typename
FloatLSE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
Acc
ElementwiseOperation
,
typename
S
ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
...
@@ -816,13 +816,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
return
to_multi_index
(
BlockwiseGemm
::
CalculateCThreadOriginDataIndex8D
(
I0
,
I0
,
I0
,
I0
));
}
template
<
typename
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
>
template
<
typename
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
tensor_operation
::
e
lement
_
wise
::
PassThrough
,
// CElementwiseOperation
E
lementwise
Op
,
// CElementwiseOperation
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
.
GetLengths
()),
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// AccessOrder
7
,
// VectorDim
...
...
@@ -1083,7 +1084,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
Acc
ElementwiseOperation
&
acc
_element_op
,
const
S
ElementwiseOperation
&
s
_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
...
...
@@ -1446,11 +1447,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
make_multi_index
(
I0
,
block_work_idx
[
I1
]
*
Gemm2Params_N_O_M
::
GemmORepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
kgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
)
>(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
tensor_operation
::
element_wise
::
PassThrough
{}
);
auto
kgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
)
,
decltype
(
s_element_op
)>(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
s_element_op
);
//
// set up Y dot dY
...
...
@@ -1673,19 +1674,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
else
{
acc
_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
s
_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
}
});
}
else
{
static_for
<
0
,
s_slash_p_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc
_element_op
(
acc_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
});
[
&
](
auto
i
)
{
s
_element_op
(
acc_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(S_i:)
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
block_sync_lds
();
// wait for gemm1 LDS read
...
...
@@ -1783,7 +1785,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
});
// gemm dQ
// dQ = dS * K
// dQ =
scalar *
dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
...
...
@@ -1846,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
}
// end gemm dQ
// dK = dS^T * dQ
// dK =
scalar *
dS^T * dQ
v_slash_k_grad_thread_buf
.
Clear
();
static_for
<
0
,
num_gemm2_loop
,
1
>
{}([
&
](
auto
gemm2_loop_idx
)
{
// gemm dK
// load KGrad Gemm B
...
...
@@ -2008,7 +2010,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
tensor_operation
::
e
lement
_
wise
::
PassThrough
,
SE
lementwise
Operation
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
...
...
@@ -2032,7 +2034,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
n_thread_data_on_block_idx
[
I2
],
n_thread_data_on_block_idx
[
I3
],
n_thread_data_on_block_idx
[
I4
]),
tensor_operation
::
element_wise
::
PassThrough
{}
};
s_element_op
};
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment