Commit c03045ce authored by Chao Liu's avatar Chao Liu
Browse files

rename

parent b2589957
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r2.hpp" #include "driver_gemm_xdlops_v2r2.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh ...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto wei_k_y_x_c_desc = const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...@@ -155,7 +152,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh ...@@ -155,7 +152,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_xdlops_v2r2< float ave_time = driver_gemm_xdlops_v2r2<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
...@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh ...@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto wei_k_y_x_c_desc = const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...@@ -224,7 +221,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh ...@@ -224,7 +221,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_xdlops_v2r3< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto wei_k_y_x_c_desc = const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 0 #if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...@@ -278,7 +275,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -278,7 +275,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_xdlops_v2r3< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" #include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" #include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei, template <typename TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -15,7 +15,7 @@ template <typename TInWei, ...@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
const auto in_n_c0_hi_wi_desc = const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
#if 1 #if 1
// cdata = 64, BlockSize = 64, 16x8x32x4 // cdata = 64, BlockSize = 64, 16x8x32x4
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_dlops_v1r2.hpp" #include "driver_contraction_dlops_v1r2.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -15,7 +15,7 @@ template <typename TInWei, ...@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( ...@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_desc_n_c_hi_wi = const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto wei_desc_k_c_y_x = const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_desc_n_k_ho_wo =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
#if 0 #if 0
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
...@@ -180,7 +177,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( ...@@ -180,7 +177,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_contraction_dlops_v1r2< float ave_time = driver_contraction_dlops_v1r2<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
......
#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" #include "gridwise_contraction_dlops_v1r2.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -45,18 +45,18 @@ template <ck::index_t BlockSize, ...@@ -45,18 +45,18 @@ template <ck::index_t BlockSize,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float __host__ float
driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseContraction = using GridwiseContraction =
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{ {
throw std::runtime_error("wrong! " throw std::runtime_error("wrong! "
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting"); "GM0_GM1_GN0_GN1 has invalid setting");
} }
...@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_dlops_v1r2< const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -204,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -204,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_dlops_v1r2< const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -230,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -230,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_dlops_v1r2< const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -256,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -256,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
} }
else else
{ {
const auto kernel = kernel_dynamic_contraction_dlops_v1r2< const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
......
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v2.hpp" #include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
...@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
// weight tensor // weight tensor
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( const auto wei_e_k_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
...@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N), make_pass_through_transform(N),
...@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor // output tensor
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Ho), make_pass_through_transform(Ho),
...@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#if 1 #if 1
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
......
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP #ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP #define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v2.hpp" #include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
...@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
<< std::endl; << std::endl;
// weight tensor // weight tensor
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( const auto wei_e_k_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N),
...@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N), make_pass_through_transform(N),
...@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor // output tensor
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor( const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH), make_pad_transform(Ho, 0, OutRightPadH),
...@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
......
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 #ifndef DRIVER_GEMM_DLOPS_V1R2
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 #define DRIVER_GEMM_DLOPS_V1R2
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" #include "gridwise_gemm_dlops_v1r2.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -48,18 +48,18 @@ template <ck::index_t BlockSize, ...@@ -48,18 +48,18 @@ template <ck::index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, __host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc, const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc, const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -72,49 +72,48 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -72,49 +72,48 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize, FloatAB,
FloatAB, FloatAcc,
FloatAcc, FloatC,
FloatC, CGlobalMemoryDataOperation,
CGlobalMemoryDataOperation, AKMGridDesc,
AKMGridDesc, BKNGridDesc,
BKNGridDesc, CMNGridDesc,
CMNGridDesc, MPerBlock,
MPerBlock, NPerBlock,
NPerBlock, KPerBlock,
KPerBlock, M1PerThread,
M1PerThread, N1PerThread,
N1PerThread, KPerThread,
KPerThread, M1N1ThreadClusterM10,
M1N1ThreadClusterM10, M1N1ThreadClusterN10,
M1N1ThreadClusterN10, M1N1ThreadClusterM11,
M1N1ThreadClusterM11, M1N1ThreadClusterN11,
M1N1ThreadClusterN11, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_M1,
ABlockTransferDstScalarPerVector_M1, AThreadTransferSrcResetCoordinateAfterRun,
AThreadTransferSrcResetCoordinateAfterRun, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_N1,
BBlockTransferDstScalarPerVector_N1, BThreadTransferSrcResetCoordinateAfterRun,
BThreadTransferSrcResetCoordinateAfterRun, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector,
CThreadTransferDstScalarPerVector, AGridIteratorHacks,
AGridIteratorHacks, BGridIteratorHacks,
BGridIteratorHacks, CGridIteratorHacks,
CGridIteratorHacks, AGridMoveSliceWindowIteratorHacks,
AGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks>;
BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k_m_grid_desc.GetLength(I1); const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1); const auto N = b_k_n_grid_desc.GetLength(I1);
...@@ -122,8 +121,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -122,8 +121,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting");
} }
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
...@@ -174,15 +172,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -174,15 +172,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -200,15 +198,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -200,15 +198,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -226,15 +224,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -226,15 +224,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -252,15 +250,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -252,15 +250,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -295,15 +293,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -295,15 +293,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -324,15 +322,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -324,15 +322,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -353,15 +351,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -353,15 +351,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -382,15 +380,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -382,15 +380,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm, kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKM0M1GridDesc>, remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>, remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
......
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 #ifndef DRIVER_GEMM_DLOPS_V1R3
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 #define DRIVER_GEMM_DLOPS_V1R3
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r3.hpp" #include "gridwise_gemm_dlops_v1r3.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -44,18 +44,18 @@ template <ck::index_t BlockSize, ...@@ -44,18 +44,18 @@ template <ck::index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, __host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -69,44 +69,44 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -69,44 +69,44 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
// GEMM // GEMM
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemmDlops_km_kn_mn_v1r3<BlockSize, GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM1Xs, M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs, M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
...@@ -114,8 +114,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -114,8 +114,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting");
} }
const auto a_k0_m0_m1_k1_grid_desc = const auto a_k0_m0_m1_k1_grid_desc =
...@@ -170,15 +169,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -170,15 +169,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -196,15 +195,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -196,15 +195,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -222,15 +221,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -222,15 +221,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -248,15 +247,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -248,15 +247,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -291,15 +290,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -291,15 +290,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -322,15 +321,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -322,15 +321,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -353,15 +352,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -353,15 +352,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -384,15 +383,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -384,15 +383,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
else else
{ {
const auto kernel = const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
......
#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 #ifndef DRIVER_GEMM_XDLOPS_V2R3
#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 #define DRIVER_GEMM_XDLOPS_V2R3
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -47,18 +47,18 @@ template <ck::index_t BlockSize, ...@@ -47,18 +47,18 @@ template <ck::index_t BlockSize,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks, typename BGridMoveSliceWindowIteratorHacks,
bool CAccessOrderMRepeatNRepeat> bool CAccessOrderMRepeatNRepeat>
__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -68,47 +68,47 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -68,47 +68,47 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerWave, MPerWave,
NPerWave, NPerWave,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
CAccessOrderMRepeatNRepeat>; CAccessOrderMRepeatNRepeat>;
{ {
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
...@@ -126,7 +126,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -126,7 +126,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
...@@ -139,13 +139,13 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -139,13 +139,13 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_dynamic_gemm_xdlops_v2r3<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0MK1GridDesc>, remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>, remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>, remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel, float ave_time = launch_and_time_kernel(kernel,
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv_bwd_data.hpp" #include "host_conv_bwd_data.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_MODE 1
#define USE_CONV_BWD_V4R1_XDL_NHWC 1 #define USE_CONV_BWD_V4R1_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
...@@ -37,7 +37,7 @@ int main(int argc, char* argv[]) ...@@ -37,7 +37,7 @@ int main(int argc, char* argv[])
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
#if USE_DYNAMIC_MODE #if USE_MODE
// dynamic mode // dynamic mode
if(argc != 22) if(argc != 22)
{ {
...@@ -212,7 +212,7 @@ int main(int argc, char* argv[]) ...@@ -212,7 +212,7 @@ int main(int argc, char* argv[])
} }
auto f_make_for_device_nhwc = [&]() { auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE #if USE_MODE
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C);
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
...@@ -253,20 +253,20 @@ int main(int argc, char* argv[]) ...@@ -253,20 +253,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk< device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<in_data_t,
in_data_t, acc_data_t,
acc_data_t, out_data_t>(
out_data_t>(tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
tmp[I4], tmp[I4],
tmp[I5], tmp[I5],
tmp[I6], tmp[I6],
in_device, in_device,
wei, wei,
out, out,
nrepeat); nrepeat);
} }
#endif #endif
...@@ -280,20 +280,20 @@ int main(int argc, char* argv[]) ...@@ -280,20 +280,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk< device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
in_data_t, acc_data_t,
acc_data_t, out_data_t>(
out_data_t>(tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
tmp[I4], tmp[I4],
tmp[I5], tmp[I5],
tmp[I6], tmp[I6],
in_device, in_device,
wei, wei,
out, out,
nrepeat); nrepeat);
} }
#endif #endif
......
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 1 #define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 1 #define USE_CONV_FWD_V6R1_NCHW 1
...@@ -49,7 +49,7 @@ int main(int argc, char* argv[]) ...@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
#if USE_DYNAMIC_MODE #if USE_MODE
// dynamic mode // dynamic mode
if(argc != 22) if(argc != 22)
{ {
...@@ -228,7 +228,7 @@ int main(int argc, char* argv[]) ...@@ -228,7 +228,7 @@ int main(int argc, char* argv[])
} }
auto f_make_for_device_nchw = [&]() { auto f_make_for_device_nchw = [&]() {
#if USE_DYNAMIC_MODE #if USE_MODE
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto wei_lengths_dev = make_tuple(K, C, Y, X);
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
...@@ -260,7 +260,7 @@ int main(int argc, char* argv[]) ...@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
}; };
auto f_make_for_device_nhwc = [&]() { auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE #if USE_MODE
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C);
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
...@@ -301,20 +301,19 @@ int main(int argc, char* argv[]) ...@@ -301,20 +301,19 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t, device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(tmp[I0],
tmp[I0], tmp[I1],
tmp[I1], tmp[I2],
tmp[I2], tmp[I3],
tmp[I3], tmp[I4],
tmp[I4], tmp[I5],
tmp[I5], tmp[I6],
tmp[I6], in,
in, wei,
wei, out_device,
out_device, nrepeat);
nrepeat);
} }
#endif #endif
...@@ -328,20 +327,19 @@ int main(int argc, char* argv[]) ...@@ -328,20 +327,19 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t, device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(tmp[I0],
tmp[I0], tmp[I1],
tmp[I1], tmp[I2],
tmp[I2], tmp[I3],
tmp[I3], tmp[I4],
tmp[I4], tmp[I5],
tmp[I5], tmp[I6],
tmp[I6], in,
in, wei,
wei, out_device,
out_device, nrepeat);
nrepeat);
} }
#endif #endif
...@@ -355,20 +353,19 @@ int main(int argc, char* argv[]) ...@@ -355,20 +353,19 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t, device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(tmp[I0],
tmp[I0], tmp[I1],
tmp[I1], tmp[I2],
tmp[I2], tmp[I3],
tmp[I3], tmp[I4],
tmp[I4], tmp[I5],
tmp[I5], tmp[I6],
tmp[I6], in,
in, wei,
wei, out_device,
out_device, nrepeat);
nrepeat);
} }
#endif #endif
...@@ -382,21 +379,20 @@ int main(int argc, char* argv[]) ...@@ -382,21 +379,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t, device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
16, 16,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(tmp[I0],
tmp[I0], tmp[I1],
tmp[I1], tmp[I2],
tmp[I2], tmp[I3],
tmp[I3], tmp[I4],
tmp[I4], tmp[I5],
tmp[I5], tmp[I6],
tmp[I6], in,
in, wei,
wei, out_device,
out_device, nrepeat);
nrepeat);
} }
#endif #endif
...@@ -410,9 +406,9 @@ int main(int argc, char* argv[]) ...@@ -410,9 +406,9 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t, device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
...@@ -437,9 +433,9 @@ int main(int argc, char* argv[]) ...@@ -437,9 +433,9 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t, device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
......
#ifndef CONV_COMMON_HPP #ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP #define CONV_COMMON_HPP
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
enum ConvTensorLayout enum ConvTensorLayout
{ {
...@@ -19,8 +19,8 @@ template <typename... InDesc, ...@@ -19,8 +19,8 @@ template <typename... InDesc,
typename LeftPads, typename LeftPads,
typename RightPads> typename RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor( constexpr auto get_convolution_output_default_4d_tensor_descriptor(
const ck::DynamicTensorDescriptor<InDesc...>& in_desc, const ck::TensorDescriptor<InDesc...>& in_desc,
const ck::DynamicTensorDescriptor<WeiDesc...>& wei_desc, const ck::TensorDescriptor<WeiDesc...>& wei_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations conv_dilations, const ConvDilations conv_dilations,
const LeftPads& left_pads, const LeftPads& left_pads,
...@@ -57,7 +57,7 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor( ...@@ -57,7 +57,7 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
} }
template <class InDesc, class WeiDesc, class OutDesc> template <class InDesc, class WeiDesc, class OutDesc>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment