"vscode:/vscode.git/clone" did not exist on "7f29ed0b95962e1776df0b1f10fc6d02a4915156"
Commit 3f976dd0 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Update batch handling

parent b9f23971
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <iterator>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Dl<
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
NDimSpatial, // NDimSpatial
InDataType, // InDataType
WeiDataType, // WeiDataType
......@@ -142,7 +141,7 @@ int run_conv_bwd_weight(bool do_verification,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
......@@ -295,16 +294,16 @@ int main(int argc, char* argv[])
WeiElementOp,
OutElementOp,
DeviceConvBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 2)
{
......@@ -332,16 +331,16 @@ int main(int argc, char* argv[])
WeiElementOp,
OutElementOp,
DeviceConvBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 3)
{
......@@ -369,16 +368,16 @@ int main(int argc, char* argv[])
WeiElementOp,
OutElementOp,
DeviceConvBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
return 0;
......
......@@ -4,6 +4,7 @@
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
......@@ -20,6 +21,83 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename DefaultBlock2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_dlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const DefaultBlock2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_kbatch_k0_m0_m1_k1,
b_grid_desc_kbatch_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
......@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
: public DeviceGroupedConvBwdWeight<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
......@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Dl;
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -116,18 +194,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static constexpr auto BBlockLdsN1Padding = 4;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
{
using namespace ck;
......@@ -255,7 +333,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
......@@ -263,23 +341,23 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
{
using namespace ck;
......@@ -436,18 +514,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
{
using namespace ck;
......@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
......@@ -759,12 +839,35 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_kbatch_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
b_grid_desc_kbatch_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_kbatch_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
b_grid_desc_kbatch_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K *
std::accumulate(begin(output_spatial_lengths),
end(output_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C *
std::accumulate(begin(input_spatial_lengths),
end(input_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C *
std::accumulate(begin(filter_spatial_lengths),
end(filter_spatial_lengths),
index_t{1},
std::multiplies<>{});
}
const ADataType* p_a_grid_;
......@@ -781,11 +884,14 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// element-wise op
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_G_;
index_t Conv_N_;
......@@ -813,25 +919,21 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}"
<< std::endl;
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}"
<< std::endl;
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ "
<< arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
......@@ -845,59 +947,63 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_gemm_dl_v1r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
integral_constant<bool, false>{});
}
}
......@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
auto str = std::stringstream();
// clang-format off
str << "DeviceConvNdBwdWeightNwcKxcNwk_Dl"
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment