Commit 822a1110 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add M and N padding

parent 5112a51e
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -269,14 +270,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -269,14 +270,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
const index_t GemmKTotal = N * Wo; const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -285,17 +282,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -285,17 +282,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -303,17 +300,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -303,17 +300,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -321,9 +318,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -321,9 +318,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
...@@ -333,17 +335,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -333,17 +335,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -372,17 +374,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -372,17 +374,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -390,9 +392,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -390,9 +392,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
...@@ -441,14 +448,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -441,14 +448,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
const index_t GemmKTotal = N * Ho * Wo; const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -457,17 +460,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -457,17 +460,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -475,17 +478,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -475,17 +478,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -493,9 +496,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -493,9 +496,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
...@@ -505,17 +513,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -505,17 +513,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -546,17 +554,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -546,17 +554,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -564,9 +572,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -564,9 +572,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
...@@ -624,14 +637,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -624,14 +637,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
const index_t GemmKTotal = N * Do * Ho * Wo; const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -640,17 +649,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -640,17 +649,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -658,17 +667,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -658,17 +667,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -676,9 +685,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -676,9 +685,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
...@@ -689,17 +703,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -689,17 +703,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -739,17 +753,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -739,17 +753,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -757,9 +771,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -757,9 +771,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
using namespace ck::tensor_layout::convolution;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdWeight : public ::testing::Test class TestGroupedConvndBwdWeight : public ::testing::Test
{ {
...@@ -35,7 +37,17 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -35,7 +37,17 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
// dl kernel is only supported for split_k=1 // dl kernel is only supported for split_k=1
if constexpr(std::is_same_v<InDataType, ck::half_t>) if constexpr(std::is_same_v<InDataType, ck::half_t>)
{ {
if(split_k == 1 && (params.K_ == 1 || params.C_ == 1)) if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
{
return true;
}
}
// 1d nhwgc is only supported by dl kernel
// dl kernel is only supported for split_k=1
if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
{
if(split_k != 1)
{ {
return true; return true;
} }
...@@ -90,8 +102,6 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple> ...@@ -90,8 +102,6 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
{ {
}; };
using namespace ck::tensor_layout::convolution;
using KernelTypes1d = ::testing::Types< using KernelTypes1d = ::testing::Types<
std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>, std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment