Commit 139b950f authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Fix comments

parent e6f4653a
...@@ -101,10 +101,10 @@ template <ck::index_t NumDimSpatial, ...@@ -101,10 +101,10 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout, typename WeiLayout,
typename OutLayout> typename OutLayout>
bool run_grouped_conv_bwd_weight( bool run_grouped_conv_bwd_weight(
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
...@@ -228,6 +228,8 @@ bool run_grouped_conv_bwd_weight( ...@@ -228,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -22,10 +22,15 @@ static constexpr ck::index_t C = 192; ...@@ -22,10 +22,15 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3; static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
G * N * Wi * C, N* Wi* C, Wi* C, C, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
G * N * Wo * K, N* Wo* K, Wo* K, K, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
int main() int main()
{ {
...@@ -35,8 +40,19 @@ int main() ...@@ -35,8 +40,19 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(G,
G, N, K, C, {Wi}, {X}, {Wo}, input_strides, output_strides, {} {1}, {1}, {1}, {1}) N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -25,10 +25,17 @@ static constexpr ck::index_t Hi = 28; ...@@ -25,10 +25,17 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Hi * Wi * C, N* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Ho * Wo * K, N* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
int main() int main()
{ {
...@@ -42,15 +49,15 @@ int main() ...@@ -42,15 +49,15 @@ int main()
N, N,
K, K,
C, C,
{Hi, Wi}, input_spatial_lengths,
{Y, X}, filter_spatial_lengths,
{Ho, Wo}, output_spatial_lengths,
input_strides, input_strides,
output_strides, output_strides,
{1, 1}, conv_filter_strides,
{1, 1}, conv_filter_dilations,
{1, 1}, input_left_pads,
{1, 1}) input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Di * Hi * Wi * C, N* Di* Hi* Wi* C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Do * Ho * Wo * K, N* Do* Ho* Wo* K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main() int main()
{ {
...@@ -45,15 +52,15 @@ int main() ...@@ -45,15 +52,15 @@ int main()
N, N,
K, K,
C, C,
{Di, Hi, Wi}, input_spatial_lengths,
{Z, Y, X}, filter_spatial_lengths,
{Do, Ho, Wo}, output_spatial_lengths,
input_strides, input_strides,
output_strides, output_strides,
{1, 1, 1}, conv_filter_strides,
{1, 1, 1}, conv_filter_dilations,
{1, 1, 1}, input_left_pads,
{1, 1, 1}) input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
G * N * Di * Hi * Wi * C, N* Di* Hi* Wi* C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
G * N * Do * Ho * Wo * K, N* Do* Ho* Wo* K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main() int main()
{ {
...@@ -41,19 +48,20 @@ int main() ...@@ -41,19 +48,20 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(
N, G,
K, N,
C, K,
{Di, Hi, Wi}, C,
{Z, Y, X}, {Di, Hi, Wi},
{Do, Ho, Wo}, {Z, Y, X},
input_strides, {Do, Ho, Wo},
output_strides, {N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
{1, 1, 1}, {N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}) {1, 1, 1},
{1, 1, 1})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -18,7 +18,7 @@ using OutElementOp = PassThrough; ...@@ -18,7 +18,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC, ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC, ck::tensor_layout::convolution::GNHWC,
......
...@@ -17,7 +17,7 @@ using OutElementOp = PassThrough; ...@@ -17,7 +17,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC, ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC, ck::tensor_layout::convolution::GNHWC,
......
...@@ -27,19 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -27,19 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
void* p_wei, void* p_wei,
const void* p_out, const void* p_out,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -784,19 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -784,19 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/, const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -899,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -899,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
InElementwiseOperation c_element_op_; InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_G_; const index_t Conv_G_;
index_t Conv_N_; const index_t Conv_N_;
index_t Conv_K_; const index_t Conv_K_;
index_t Conv_C_; const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
std::array<ck::index_t, NDimSpatial> input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
std::array<ck::index_t, NDimSpatial> input_right_pads_; const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_; index_t k_batch_;
}; };
...@@ -1113,19 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1113,19 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto MakeArgument(const InDataType* p_in_grid, static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1159,19 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1159,19 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -1086,21 +1086,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1086,21 +1086,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t M01, const ck::index_t M01,
ck::index_t N01, const ck::index_t N01,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1194,16 +1194,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1194,16 +1194,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
WeiElementwiseOperation c_element_op_; WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_G_; const index_t Conv_G_;
index_t Conv_N_; const index_t Conv_N_;
index_t Conv_K_; const index_t Conv_K_;
index_t Conv_C_; const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
std::array<ck::index_t, NDimSpatial>& input_right_pads_; const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_; const index_t k_batch_;
}; };
// Invoker // Invoker
...@@ -1390,23 +1390,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1390,23 +1390,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static auto MakeArgument(const InDataType* p_in_grid, static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
ck::index_t split_k) const ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -1438,23 +1438,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1438,23 +1438,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial>& input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
ck::index_t split_k) override const ck::index_t split_k) override
{ {
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
...@@ -23,11 +24,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -23,11 +24,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using InLayout = std::tuple_element_t<3, Tuple>; using InLayout = std::tuple_element_t<3, Tuple>;
using WeiLayout = std::tuple_element_t<4, Tuple>; using WeiLayout = std::tuple_element_t<4, Tuple>;
using OutLayout = std::tuple_element_t<5, Tuple>; using OutLayout = std::tuple_element_t<5, Tuple>;
using NDimSpatial = std::tuple_element_t<6, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params; std::vector<ck::utils::conv::ConvParam> conv_params;
ck::index_t split_k{2}; ck::index_t split_k{2};
template <ck::index_t NDimSpatial>
void Run() void Run()
{ {
EXPECT_FALSE(conv_params.empty()); EXPECT_FALSE(conv_params.empty());
...@@ -35,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -35,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
for(auto& param : conv_params) for(auto& param : conv_params)
{ {
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial, pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
...@@ -70,21 +71,21 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple> ...@@ -70,21 +71,21 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
using namespace ck::tensor_layout::convolution; using namespace ck::tensor_layout::convolution;
using KernelTypes1d = using KernelTypes1d = ::testing::Types<
::testing::Types<std::tuple<float, float, float, GNWC, GKXC, GNWK>, std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>>;
using KernelTypes2d = using KernelTypes2d = ::testing::Types<
::testing::Types<std::tuple<float, float, float, GNHWC, GKYXC, GNHWK>, std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK>, std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK>, std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>>;
using KernelTypes3d = using KernelTypes3d = ::testing::Types<
::testing::Types<std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK>, std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
...@@ -96,7 +97,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) ...@@ -96,7 +97,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>(); this->Run();
} }
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
...@@ -108,7 +109,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) ...@@ -108,7 +109,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->Run();
} }
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
...@@ -120,5 +121,5 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) ...@@ -120,5 +121,5 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>(); this->Run();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment