Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3f976dd0
Commit
3f976dd0
authored
Jan 10, 2023
by
Rosty Geyyer
Browse files
Update batch handling
parent
b9f23971
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
233 additions
and
129 deletions
+233
-129
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+36
-37
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+197
-92
No files found.
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
3f976dd0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include <algorithm>
#include <iostream>
#include <iostream>
#include <numeric>
#include <iterator>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
...
@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Dl
<
ck
::
tensor_operation
::
device
::
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Dl
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
WeiDataType
,
// WeiDataType
...
@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification,
...
@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
...
@@ -295,16 +294,16 @@ int main(int argc, char* argv[])
...
@@ -295,16 +294,16 @@ int main(int argc, char* argv[])
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConvBwdWeightInstance
<
1
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
1
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
in_g_n_c_wis_desc
,
in_g_n_c_wis_desc
,
wei_g_k_c_xs_desc
,
wei_g_k_c_xs_desc
,
out_g_n_k_wos_desc
,
out_g_n_k_wos_desc
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
,
out_element_op
,
split_k
);
split_k
);
}
}
else
if
(
conv_param
.
num_dim_spatial_
==
2
)
else
if
(
conv_param
.
num_dim_spatial_
==
2
)
{
{
...
@@ -332,16 +331,16 @@ int main(int argc, char* argv[])
...
@@ -332,16 +331,16 @@ int main(int argc, char* argv[])
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConvBwdWeightInstance
<
2
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
2
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
in_g_n_c_wis_desc
,
in_g_n_c_wis_desc
,
wei_g_k_c_xs_desc
,
wei_g_k_c_xs_desc
,
out_g_n_k_wos_desc
,
out_g_n_k_wos_desc
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
,
out_element_op
,
split_k
);
split_k
);
}
}
else
if
(
conv_param
.
num_dim_spatial_
==
3
)
else
if
(
conv_param
.
num_dim_spatial_
==
3
)
{
{
...
@@ -369,16 +368,16 @@ int main(int argc, char* argv[])
...
@@ -369,16 +368,16 @@ int main(int argc, char* argv[])
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConvBwdWeightInstance
<
3
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
3
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
in_g_n_c_wis_desc
,
in_g_n_c_wis_desc
,
wei_g_k_c_xs_desc
,
wei_g_k_c_xs_desc
,
out_g_n_k_wos_desc
,
out_g_n_k_wos_desc
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
,
out_element_op
,
split_k
);
split_k
);
}
}
return
0
;
return
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
3f976dd0
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment