Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
df22ba01
Commit
df22ba01
authored
Mar 02, 2022
by
ltqin
Browse files
start to use atomic add
parent
162359b6
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
186 additions
and
178 deletions
+186
-178
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+154
-163
example/14_conv2d_backward_weight_xdl/main.cpp
example/14_conv2d_backward_weight_xdl/main.cpp
+32
-15
No files found.
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
df22ba01
This diff is collapsed.
Click to expand it.
example/14_conv2d_backward_weight_xdl/main.cpp
View file @
df22ba01
...
@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
...
@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
32
,
// NPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// MXdlPerWave
2
,
// NXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
2
,
0
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
true
,
// ABlockLdsAddExtraM
S
<
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
//
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
S
<
1
,
1
6
,
1
,
4
>
,
//
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
...
@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
...
@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// Conv shape
// Conv shape
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
256
;
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
1
9
2
;
ck
::
index_t
C
=
12
8
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Hi
=
71
;
...
@@ -95,6 +95,7 @@ int main(int argc, char* argv[])
...
@@ -95,6 +95,7 @@ int main(int argc, char* argv[])
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
split_k
=
1
;
if
(
argc
==
4
)
if
(
argc
==
4
)
{
{
...
@@ -102,7 +103,7 @@ int main(int argc, char* argv[])
...
@@ -102,7 +103,7 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
19
)
else
if
(
argc
==
20
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -123,6 +124,7 @@ int main(int argc, char* argv[])
...
@@ -123,6 +124,7 @@ int main(int argc, char* argv[])
in_left_pad_w
=
std
::
stoi
(
argv
[
16
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
16
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
17
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
17
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
18
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
18
]);
split_k
=
std
::
stoi
(
argv
[
19
]);
}
}
else
else
{
{
...
@@ -185,12 +187,13 @@ int main(int argc, char* argv[])
...
@@ -185,12 +187,13 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
Wei
DataType
>
{
-
5
,
5
});
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
Out
DataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_
3
<
InDataType
>
{
0.0
,
1.0
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_
1
<
InDataType
>
{
1
});
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_
3
<
Wei
DataType
>
{
-
0.5
,
0.5
});
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_
1
<
Out
DataType
>
{
1
});
}
}
wei_k_c_y_x_device_result
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{
0
});
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
...
@@ -199,6 +202,9 @@ int main(int argc, char* argv[])
...
@@ -199,6 +202,9 @@ int main(int argc, char* argv[])
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device(before): "
,
wei_k_c_y_x_device_result
.
mData
,
","
)
<<
std
::
endl
;
// do GEMM
// do GEMM
auto
conv
=
DeviceConvWrWInstance
{};
auto
conv
=
DeviceConvWrWInstance
{};
...
@@ -218,7 +224,8 @@ int main(int argc, char* argv[])
...
@@ -218,7 +224,8 @@ int main(int argc, char* argv[])
input_right_pads
,
input_right_pads
,
InElementOp
{},
InElementOp
{},
WeiElementOp
{},
WeiElementOp
{},
OutElementOp
{});
OutElementOp
{},
split_k
);
if
(
!
conv
.
IsSupportedArgument
(
argument
))
if
(
!
conv
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -262,6 +269,16 @@ int main(int argc, char* argv[])
...
@@ -262,6 +269,16 @@ int main(int argc, char* argv[])
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
if
(
1
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out: "
,
out_n_k_ho_wo
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in_n_c_hi_wi
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device(after): "
,
wei_k_c_y_x_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_host : "
,
wei_k_c_y_x_host_result
.
mData
,
","
)
<<
std
::
endl
;
}
check_error
(
wei_k_c_y_x_host_result
,
wei_k_c_y_x_device_result
);
check_error
(
wei_k_c_y_x_host_result
,
wei_k_c_y_x_device_result
);
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment