Commit a116cb0b authored by Chao Liu's avatar Chao Liu
Browse files

testing wrw

parent 75a5a175
...@@ -23,7 +23,7 @@ template <typename... In, ...@@ -23,7 +23,7 @@ template <typename... In,
index_t GemmK1Value, index_t GemmK1Value,
typename GemmKBatchType> typename GemmKBatchType>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
......
...@@ -52,7 +52,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -52,7 +52,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0 #if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
...@@ -80,7 +80,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -80,7 +80,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0 #elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32 // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -134,19 +134,159 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -134,19 +134,159 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( const auto descs =
in_n_hi_wi_c_desc, transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc, wei_k_y_x_c_desc,
out_n_ho_wo_k_desc, out_n_ho_wo_k_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<GemmK1>{}, Number<GemmK1>{},
GemmKBatch); GemmKBatch);
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
...@@ -223,7 +363,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -223,7 +363,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
Sequence<0, 1, 3, 2>, Sequence<0, 1, 2, 3>,
2, 2,
GemmABlockTransferSrcScalarPerVector_GemmM, GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1, GemmABlockTransferDstScalarPerVector_GemmK1,
......
...@@ -130,7 +130,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -130,7 +130,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting");
} }
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv_bwd_weight.hpp" #include "host_conv_bwd_weight.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" //#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp" //#include
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" //"device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
//#include
//"device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
...@@ -27,11 +29,11 @@ ...@@ -27,11 +29,11 @@
enum ConvBackwardWeightAlgo enum ConvBackwardWeightAlgo
{ {
V4R4R2XDLNCHW, V4R4R2XDLNCHW, // 0
V4R4R4XDLNHWC, V4R4R4XDLNHWC, // 1
V4R4R2XDLATOMICNCHW, V4R4R2XDLATOMICNCHW, // 2
V4R4R4XDLATOMICNHWC, V4R4R4XDLATOMICNHWC, // 3
V4R4R5XDLATOMICNHWC, V4R4R5XDLATOMICNHWC, // 4
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment