Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
73416101
Commit
73416101
authored
Mar 03, 2022
by
ltqin
Browse files
fix atomic and set operator choice
parent
4f940d01
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
159 additions
and
73 deletions
+159
-73
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+128
-44
example/14_conv2d_backward_weight_xdl/main.cpp
example/14_conv2d_backward_weight_xdl/main.cpp
+31
-29
No files found.
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
73416101
...
@@ -189,7 +189,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -189,7 +189,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
InMemoryDataOperationEnum_t
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -225,6 +225,46 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -225,6 +225,46 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
// Argument
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
...
@@ -335,23 +375,27 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -335,23 +375,27 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
void
ShowInfo
(
const
Argument
&
arg
)
{
{
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
}
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
...
@@ -418,37 +462,77 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -418,37 +462,77 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
InElementwiseOperation
,
true
>
;
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Run
(
kernel
);
true
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
InElementwiseOperation
,
false
>
;
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Run
(
kernel
);
false
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
}
}
return
ave_time
;
return
ave_time
;
...
...
example/14_conv2d_backward_weight_xdl/main.cpp
View file @
73416101
...
@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
...
@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
32
,
// NPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// MXdlPerWave
2
,
// NXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
16
,
1
,
4
>
,
//
S
<
1
,
32
,
1
,
4
>
,
//
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
...
@@ -79,6 +79,7 @@ int main(int argc, char* argv[])
...
@@ -79,6 +79,7 @@ int main(int argc, char* argv[])
int
init_method
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
int
nrepeat
=
5
;
int
do_log
=
0
;
int
do_log
=
0
;
int
split_k
=
1
;
// Conv shape
// Conv shape
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
128
;
...
@@ -96,14 +97,14 @@ int main(int argc, char* argv[])
...
@@ -96,14 +97,14 @@ int main(int argc, char* argv[])
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
split_k
=
1
;
if
(
argc
==
5
)
if
(
argc
==
6
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
}
}
else
if
(
argc
==
21
)
else
if
(
argc
==
21
)
{
{
...
@@ -111,23 +112,23 @@ int main(int argc, char* argv[])
...
@@ -111,23 +112,23 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
N
=
std
::
stoi
(
argv
[
6
]);
C
=
std
::
stoi
(
argv
[
7
]);
K
=
std
::
stoi
(
argv
[
7
]);
Y
=
std
::
stoi
(
argv
[
8
]);
C
=
std
::
stoi
(
argv
[
8
]);
X
=
std
::
stoi
(
argv
[
9
]);
Y
=
std
::
stoi
(
argv
[
9
]);
Hi
=
std
::
stoi
(
argv
[
10
]);
X
=
std
::
stoi
(
argv
[
10
]);
W
i
=
std
::
stoi
(
argv
[
11
]);
H
i
=
std
::
stoi
(
argv
[
11
]);
conv_stride_h
=
std
::
stoi
(
argv
[
12
]);
Wi
=
std
::
stoi
(
argv
[
12
]);
conv_stride_
w
=
std
::
stoi
(
argv
[
13
]);
conv_stride_
h
=
std
::
stoi
(
argv
[
13
]);
conv_
dilation_h
=
std
::
stoi
(
argv
[
14
]);
conv_
stride_w
=
std
::
stoi
(
argv
[
14
]);
conv_dilation_
w
=
std
::
stoi
(
argv
[
15
]);
conv_dilation_
h
=
std
::
stoi
(
argv
[
15
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
16
]);
conv_dilation_w
=
std
::
stoi
(
argv
[
16
]);
in_left_pad_
w
=
std
::
stoi
(
argv
[
17
]);
in_left_pad_
h
=
std
::
stoi
(
argv
[
17
]);
in_
righ
t_pad_
h
=
std
::
stoi
(
argv
[
18
]);
in_
lef
t_pad_
w
=
std
::
stoi
(
argv
[
18
]);
in_right_pad_
w
=
std
::
stoi
(
argv
[
19
]);
in_right_pad_
h
=
std
::
stoi
(
argv
[
19
]);
split_k
=
std
::
stoi
(
argv
[
20
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
20
]);
}
}
else
else
{
{
...
@@ -231,9 +232,10 @@ int main(int argc, char* argv[])
...
@@ -231,9 +232,10 @@ int main(int argc, char* argv[])
if
(
!
conv
.
IsSupportedArgument
(
argument
))
if
(
!
conv
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
"not support this Conv problem"
);
<<
std
::
endl
;
return
1
;
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment