Commit 3f976dd0 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Update batch handling

parent b9f23971
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include <algorithm>
#include <iostream> #include <iostream>
#include <numeric> #include <iterator>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault = ...@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Dl< ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
NDimSpatial, // NDimSpatial NDimSpatial, // NDimSpatial
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
...@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification, ...@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{}; std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
...@@ -295,16 +294,16 @@ int main(int argc, char* argv[]) ...@@ -295,16 +294,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<1>>(do_verification, DeviceConvBwdWeightInstance<1>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(conv_param.num_dim_spatial_ == 2) else if(conv_param.num_dim_spatial_ == 2)
{ {
...@@ -332,16 +331,16 @@ int main(int argc, char* argv[]) ...@@ -332,16 +331,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<2>>(do_verification, DeviceConvBwdWeightInstance<2>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(conv_param.num_dim_spatial_ == 3) else if(conv_param.num_dim_spatial_ == 3)
{ {
...@@ -369,16 +368,16 @@ int main(int argc, char* argv[]) ...@@ -369,16 +368,16 @@ int main(int argc, char* argv[])
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvBwdWeightInstance<3>>(do_verification, DeviceConvBwdWeightInstance<3>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment