Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
394f9207
Commit
394f9207
authored
Dec 27, 2022
by
Anthony Chang
Browse files
remove unused variables
parent
fc7e83ee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
28 deletions
+9
-28
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+0
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+9
-26
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
394f9207
...
@@ -23,8 +23,6 @@ Kernel outputs:
...
@@ -23,8 +23,6 @@ Kernel outputs:
*/
*/
#pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
#define PRINT_HOST 0
#define PRINT_HOST 0
#include <iostream>
#include <iostream>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
394f9207
...
@@ -884,8 +884,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -884,8 +884,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
template
<
typename
VGridDesc_N0_O_N1_
>
template
<
typename
VGridDesc_N0_O_N1_
>
__device__
static
auto
__device__
static
auto
MakeVGridDesc_O0_N_O1
(
const
VGridDesc_N0_O_N1_
&
v_grid_desc_n0_o_n1
)
MakeVGridDesc_O0_N_O1
(
const
VGridDesc_N0_O_N1_
&
v_grid_desc_n0_o_n1
)
{
{
const
auto
N0
=
v_grid_desc_n0_o_n1
.
GetLength
(
I0
);
const
auto
N0
=
v_grid_desc_n0_o_n1
.
GetLength
(
I0
);
const
auto
O
=
v_grid_desc_n0_o_n1
.
GetLength
(
I1
);
const
auto
O
=
v_grid_desc_n0_o_n1
.
GetLength
(
I1
);
...
@@ -936,8 +935,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -936,8 +935,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
template
<
typename
KGridDesc_K0_N_K1_
>
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
auto
__device__
static
auto
MakeKGridDesc_N0_K_N1
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
MakeKGridDesc_N0_K_N1
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
{
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
...
@@ -961,8 +959,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -961,8 +959,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
{
// B position
// B position
template
<
typename
QGridDesc_K0_M_K1_
>
template
<
typename
QGridDesc_K0_M_K1_
>
__device__
static
auto
__device__
static
auto
MakeQGridDesc_M0_K_M1
(
const
QGridDesc_K0_M_K1_
&
q_grid_desc_k0_m_k1
)
MakeQGridDesc_M0_K_M1
(
const
QGridDesc_K0_M_K1_
&
q_grid_desc_k0_m_k1
)
{
{
const
auto
Q_K0
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
Q_K0
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
@@ -983,8 +980,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -983,8 +980,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// C position
// C position
template
<
typename
KGridDesc_K0_N_K1_
>
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
auto
__device__
static
auto
MakeKGradGridDesc_N_K
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
MakeKGradGridDesc_N_K
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
{
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
...
@@ -1202,7 +1198,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1202,7 +1198,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
v_grid_desc_o0_n_o1
=
const
auto
v_grid_desc_o0_n_o1
=
PGradGemmTile_M_N_O
::
MakeVGridDesc_O0_N_O1
(
v_grid_desc_n0_o_n1
);
PGradGemmTile_M_N_O
::
MakeVGridDesc_O0_N_O1
(
v_grid_desc_n0_o_n1
);
// dP:
Gemm A position
blockwise copy
// dP:
A matrix
blockwise copy
auto
pgrad_gemm_tile_ygrad_blockwise_copy
=
auto
pgrad_gemm_tile_ygrad_blockwise_copy
=
typename
Gemm0
::
template
ABlockwiseCopy
<
decltype
(
ygrad_grid_desc_o0_m_o1
)>(
typename
Gemm0
::
template
ABlockwiseCopy
<
decltype
(
ygrad_grid_desc_o0_m_o1
)>(
ygrad_grid_desc_o0_m_o1
,
ygrad_grid_desc_o0_m_o1
,
...
@@ -1212,7 +1208,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1212,7 +1208,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
// dP:
Gemm B position
blockwise copy
// dP:
B matrix
blockwise copy
auto
pgrad_gemm_tile_v_blockwise_copy
=
auto
pgrad_gemm_tile_v_blockwise_copy
=
typename
Gemm0
::
template
BBlockwiseCopy
<
decltype
(
v_grid_desc_o0_n_o1
)>(
typename
Gemm0
::
template
BBlockwiseCopy
<
decltype
(
v_grid_desc_o0_n_o1
)>(
v_grid_desc_o0_n_o1
,
v_grid_desc_o0_n_o1
,
...
@@ -1283,10 +1279,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1283,10 +1279,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
//
//
// Blockwise softmax
// Blockwise softmax
//
//
auto
workspace_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
SharedMemTrait
::
reduction_space_size_aligned
);
// get acc0 8D thread cluster
// get acc0 8D thread cluster
constexpr
auto
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
=
constexpr
auto
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
()
/
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
()
/
...
@@ -1381,7 +1373,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1381,7 +1373,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm2
::
b_block_desc_m0_o_m1
.
GetElementSpaceSize
());
Gemm2
::
b_block_desc_m0_o_m1
.
GetElementSpaceSize
());
// dV: transform input and output tensor descriptors
// dV: transform input and output tensor descriptors
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
=
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
// dV: A matrix VGPR-to-LDS blockwise copy
// dV: A matrix VGPR-to-LDS blockwise copy
...
@@ -1390,9 +1382,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1390,9 +1382,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
constexpr
auto
vgrad_gemm_tile_p_block_slice_window_iterator
=
typename
Gemm2
::
ASrcBlockSliceWindowIterator
{};
// dV: B matrix global-to-LDS blockwise copy
// dV: B matrix global-to-LDS blockwise copy
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
typename
Gemm2
::
template
BBlockwiseCopy
<
decltype
(
ygrad_grid_desc_m0_o_m1
)>(
typename
Gemm2
::
template
BBlockwiseCopy
<
decltype
(
ygrad_grid_desc_m0_o_m1
)>(
...
@@ -1411,9 +1400,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1411,9 +1400,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto
v_slash_k_grad_thread_buf
=
v_slash_k_grad_blockwise_gemm
.
GetCThreadBuffer
();
auto
v_slash_k_grad_thread_buf
=
v_slash_k_grad_blockwise_gemm
.
GetCThreadBuffer
();
// dV: C VGPR-to-global copy
// dV: C VGPR-to-global copy
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
const
auto
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
const
auto
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
make_multi_index
(
...
@@ -1430,6 +1416,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1430,6 +1416,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
KGradGemmTile_N_K_M
::
MakeQGridDesc_M0_K_M1
(
q_grid_desc_k0_m_k1
);
KGradGemmTile_N_K_M
::
MakeQGridDesc_M0_K_M1
(
q_grid_desc_k0_m_k1
);
const
auto
kgrad_grid_desc_n_k
=
const
auto
kgrad_grid_desc_n_k
=
KGradGemmTile_N_K_M
::
MakeKGradGridDesc_N_K
(
k_grid_desc_k0_n_k1
);
KGradGemmTile_N_K_M
::
MakeKGradGridDesc_N_K
(
k_grid_desc_k0_n_k1
);
const
auto
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
kgrad_grid_desc_n_k
);
// dK: A matrix VGPR-to-LDS blockwise copy
// dK: A matrix VGPR-to-LDS blockwise copy
auto
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
ABlockwiseCopy
{
auto
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
ABlockwiseCopy
{
...
@@ -1453,9 +1441,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1453,9 +1441,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
// dK: C VGPR-to-global copy
// dK: C VGPR-to-global copy
const
auto
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
kgrad_grid_desc_n_k
);
const
auto
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
const
auto
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
make_multi_index
(
...
@@ -1792,8 +1777,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1792,8 +1777,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr
auto
pgrad_thread_idx
=
pgrad_thread_tile_iterator
.
GetIndex
(
i
);
constexpr
auto
pgrad_thread_idx
=
pgrad_thread_tile_iterator
.
GetIndex
(
i
);
constexpr
auto
m
=
constexpr
auto
m
=
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
constexpr
auto
n
=
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I1
];
// dS and P has same thread buf layout
// dS and P has same thread buf layout
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment