Commit 139b950f authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Fix comments

parent e6f4653a
......@@ -101,10 +101,10 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout,
typename OutLayout>
bool run_grouped_conv_bwd_weight(
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
......@@ -228,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
......@@ -22,10 +22,15 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Wi * C, N* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Wo * K, N* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
int main()
{
......@@ -35,8 +40,19 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
G, N, K, C, {Wi}, {X}, {Wo}, input_strides, output_strides, {} {1}, {1}, {1}, {1})
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -25,10 +25,17 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Hi * Wi * C, N* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Ho * Wo * K, N* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
int main()
{
......@@ -42,15 +49,15 @@ int main()
N,
K,
C,
{Hi, Wi},
{Y, X},
{Ho, Wo},
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
{1, 1},
{1, 1},
{1, 1},
{1, 1})
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Di * Hi * Wi * C, N* Di* Hi* Wi* C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Do * Ho * Wo * K, N* Do* Ho* Wo* K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main()
{
......@@ -45,15 +52,15 @@ int main()
N,
K,
C,
{Di, Hi, Wi},
{Z, Y, X},
{Do, Ho, Wo},
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Di * Hi * Wi * C, N* Di* Hi* Wi* C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Do * Ho * Wo * K, N* Do* Ho* Wo* K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main()
{
......@@ -41,19 +48,20 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(G,
N,
K,
C,
{Di, Hi, Wi},
{Z, Y, X},
{Do, Ho, Wo},
input_strides,
output_strides,
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
OutLayout>(
G,
N,
K,
C,
{Di, Hi, Wi},
{Z, Y, X},
{Do, Ho, Wo},
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -18,7 +18,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
......
......@@ -17,7 +17,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
......
......@@ -27,19 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in,
void* p_wei,
const void* p_out,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -784,19 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -899,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_G_;
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
std::array<ck::index_t, NDimSpatial> input_left_pads_;
std::array<ck::index_t, NDimSpatial> input_right_pads_;
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_;
};
......@@ -1113,19 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -1159,19 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -1086,21 +1086,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t M01,
ck::index_t N01,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t M01,
const ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -1194,16 +1194,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_G_;
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
std::array<ck::index_t, NDimSpatial>& input_left_pads_;
std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_;
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
};
// Invoker
......@@ -1390,23 +1390,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
ck::index_t split_k)
const ck::index_t split_k)
{
return Argument{p_in_grid,
p_wei_grid,
......@@ -1438,23 +1438,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
ck::index_t split_k) override
const ck::index_t split_k) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
......
......@@ -9,6 +9,7 @@
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
......@@ -23,11 +24,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using InLayout = std::tuple_element_t<3, Tuple>;
using WeiLayout = std::tuple_element_t<4, Tuple>;
using OutLayout = std::tuple_element_t<5, Tuple>;
using NDimSpatial = std::tuple_element_t<6, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
ck::index_t split_k{2};
template <ck::index_t NDimSpatial>
void Run()
{
EXPECT_FALSE(conv_params.empty());
......@@ -35,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
for(auto& param : conv_params)
{
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
InLayout,
WeiLayout,
OutLayout,
......@@ -70,21 +71,21 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
using namespace ck::tensor_layout::convolution;
using KernelTypes1d =
::testing::Types<std::tuple<float, float, float, GNWC, GKXC, GNWK>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK>>;
using KernelTypes2d =
::testing::Types<std::tuple<float, float, float, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK>>;
using KernelTypes3d =
::testing::Types<std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>>;
using KernelTypes1d = ::testing::Types<
std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>>;
using KernelTypes2d = ::testing::Types<
std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>>;
using KernelTypes3d = ::testing::Types<
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
......@@ -96,7 +97,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>();
this->Run();
}
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
......@@ -108,7 +109,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>();
this->Run();
}
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
......@@ -120,5 +121,5 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
this->Run();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment