Unverified commit 38a90b6e, authored by Chao Liu, committed by GitHub

Merge pull request #43 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 88833bd9 c3018794
#ifndef DRIVER_GEMM_XDLOPS_V2R4
#define DRIVER_GEMM_XDLOPS_V2R4
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t K1,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsAddExtraM,
bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
ck::index_t M01,
ck::index_t N01,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
using GridwiseGemm =
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
ABK0MK1GridDesc,
BBK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
{
std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
<< a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
<< a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
<< b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
<< b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 has invalid setting");
}
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
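// Note: MakeCM0N0M1N1M2M3M4N2GridDescriptor presumably re-tiles the M x N output
// into block/wave/xdlops sub-dimensions (M0..M4, N0..N2); the exact mapping is
// defined inside the gridwise GEMM.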
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
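// Dimension I0 of the A/B descriptors is the split-K batch count: each K-batch is
// reduced by its own slice of the grid, and the partial C results are combined
// according to CGlobalMemoryDataOperation (e.g. atomic add for the split-K paths).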
const auto c_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
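// e.g. (illustrative) a 3840 x 4096 C with 256 x 128 blocks and KBatch = 4 would
// need (3840 / 256) * (4096 / 128) * 4 = 1920 workgroups.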
{
std::cout << "gridSize : " << grid_size << std::endl;
}
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<ABK0MK1GridDesc>,
remove_reference_t<BBK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
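// Two experimental ways to hand the compile-time tensor descriptors to the kernel:
// either pass them by value as kernel arguments, or copy them into device buffers
// and pass pointers cast to the constant address space.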
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;
}
#endif
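// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this header): one way a caller might build
// the [B, K0, M, K1] split-K descriptors this driver consumes. The problem
// size, the KBatch/K0/K1 factorization, and the packed row-major strides below
// are assumptions for illustration only.
//
//   using namespace ck;
//   index_t KBatch = 8, K0 = 128, K1 = 4; // reduction length K = KBatch * K0 * K1 = 4096
//   index_t M = 3840, N = 4096;
//
//   // A viewed as [B, K0, M, K1], packed with K1 fastest-varying
//   const auto a_b_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
//       make_tuple(KBatch, K0, M, K1), make_tuple(K0 * M * K1, M * K1, K1, 1));
//
//   // B viewed as [B, K0, N, K1], same packing assumption
//   const auto b_b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
//       make_tuple(KBatch, K0, N, K1), make_tuple(K0 * N * K1, N * K1, K1, 1));
//
//   // C viewed as [M, N], row-major
//   const auto c_m_n_grid_desc =
//       make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, 1));
//
// These would then be passed to driver_gemm_xdlops_v2r4 together with the block
// and xdlops tuning parameters chosen for the target GPU.
// ---------------------------------------------------------------------------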
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -14,15 +15,16 @@
#include "device_tensor.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
#define USE_MODE 1
-#define USE_CONV_BWD_V4R1_XDL_NHWC 1
+#define USE_CONV_BWD_V4R1_XDL_NHWC 0
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum ConvBackwardDataAlgo
{
-V4R1XDLNHWC,
-V4R1R2XDLNHWC,
+V4R1XDLNHWC, // 0
+V4R1R2XDLNHWC, // 1
};
int main(int argc, char* argv[])
@@ -280,20 +282,43 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc();
-device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
-acc_data_t,
-out_data_t>(
-tmp[I0],
-tmp[I1],
-tmp[I2],
-tmp[I3],
-tmp[I4],
-tmp[I5],
-tmp[I6],
-in_device,
-wei,
-out,
-nrepeat);
+if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 &&
+in_right_pad_w == 0)
+{
+device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1<
+in_data_t,
+acc_data_t,
+out_data_t>(tmp[I0],
+tmp[I1],
+tmp[I2],
+tmp[I3],
+tmp[I4],
+tmp[I5],
+tmp[I6],
+in_device,
+wei,
+out,
+nrepeat);
+}
+else
+{
+#if 1
+device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
+acc_data_t,
+out_data_t>(
+tmp[I0],
+tmp[I1],
+tmp[I2],
+tmp[I3],
+tmp[I4],
+tmp[I5],
+tmp[I6],
+in_device,
+wei,
+out,
+nrepeat);
+#endif
+}
}
#endif
......
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -24,7 +25,7 @@
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo
......
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -13,13 +14,25 @@
#include "host_conv_bwd_weight.hpp"
#include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
-#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
+#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
enum ConvBackwardWeightAlgo
{
-V4R4R2XDLNCHW,
+V4R4R2XDLNCHW, // 0
+V4R4R4XDLNHWC, // 1
+V4R4R2XDLATOMICNCHW, // 2
+V4R4R4XDLATOMICNHWC, // 3
+V4R4R5XDLATOMICNHWC, // 4
};
int main(int argc, char* argv[])
@@ -36,10 +49,11 @@ int main(int argc, char* argv[])
#if USE_DYNAMIC_MODE
// dynamic mode
-if(argc != 22)
+if(argc != 23)
{
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
printf("additional: desired_grid_size\n");
exit(1);
}
@@ -67,6 +81,8 @@ int main(int argc, char* argv[])
const index_t in_right_pad_h = std::stoi(argv[20]);
const index_t in_right_pad_w = std::stoi(argv[21]);
const index_t desired_grid_size = std::stoi(argv[22]);
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1;
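// e.g. a 3x3 filter with dilation 2 has effective extent YEff = (3 - 1) * 2 + 1 = 5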
@@ -111,18 +127,21 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
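// e.g. Wi = 14, pads 1 + 1, XEff = 3, stride 1: Wo = (14 + 2 - 3) / 1 + 1 = 14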
#endif
-#if 1
+#if 0
using in_data_t = float;
using wei_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
using in_data_t = half_t;
-using acc_data_t = float;
using out_data_t = half_t;
+using acc_data_t = float;
+using wei_data_t = float;
#elif 1
using in_data_t = int8_t;
-using acc_data_t = int32_t;
using out_data_t = int8_t;
+using acc_data_t = int32_t;
+using wei_data_t = int8_t;
#endif
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
@@ -163,8 +182,8 @@ int main(int argc, char* argv[])
}
Tensor<in_data_t> in(in_lengths_host);
-Tensor<in_data_t> wei_device(wei_lengths_host);
-Tensor<out_data_t> wei_host(wei_lengths_host);
+Tensor<wei_data_t> wei_device(wei_lengths_host);
+Tensor<wei_data_t> wei_host(wei_lengths_host);
Tensor<out_data_t> out(out_lengths_host);
std::cout << "layout: " << layout << std::endl;
@@ -230,6 +249,26 @@ int main(int argc, char* argv[])
in_right_pads_dev);
};
auto f_make_for_device_nhwc = [&]() {
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
return make_tuple(in_lengths_dev,
wei_lengths_dev,
out_lengths_dev,
conv_strides_dev,
conv_dilations_dev,
in_left_pads_dev,
in_right_pads_dev);
};
// set zero to wei_device
wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
#if USE_CONV_WRW_V4R4R2_XDL_NCHW
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
{
@@ -241,6 +280,35 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
wei_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei_device,
out,
nrepeat);
}
#endif
#if USE_CONV_WRW_V4R4R4_XDL_NHWC
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
wei_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
@@ -257,6 +325,93 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
in_data_t,
wei_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei_device,
out,
desired_grid_size,
nrepeat);
}
#endif
#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk<
in_data_t,
wei_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei_device,
out,
desired_grid_size,
nrepeat);
}
#endif
#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC
if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk<
in_data_t,
wei_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei_device,
out,
desired_grid_size,
nrepeat);
}
#endif
if(do_verification)
{
host_direct_convolution_backward_weights(out,
......
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
@@ -16,11 +17,19 @@
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum GemmAlgo
{
@@ -28,21 +37,21 @@ enum GemmAlgo
Xdl_MK_NK_MN, // 1
Xdl_KM_KN_MN, // 2
Xdl_KM_NK_MN, // 3
Xdl_MK_KN_NM, // 4
Xdl_MK_NK_NM, // 5
Xdl_KM_KN_NM, // 6
Xdl_KM_NK_NM, // 7
};
int main(int argc, char* argv[])
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
// dynamic mode
-if(argc != 10)
+if(argc != 12)
{
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: M, N, K\n");
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
exit(1);
}
@@ -57,6 +66,9 @@ int main(int argc, char* argv[])
const index_t N = std::stoi(argv[8]);
const index_t K = std::stoi(argv[9]);
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
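// M01/N01 (assumed) tune the order in which workgroups are mapped to C tiles,
// mirroring the M01/N01 arguments of the xdlops GEMM drivers.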
#if 0
using ab_data_t = float;
using acc_data_t = float;
@@ -74,69 +86,44 @@ int main(int argc, char* argv[])
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
-if(layout == GemmMatrixLayout::MK_KN_MN)
+// A
+if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
+layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
{
a_lengths_host[0] = static_cast<std::size_t>(M);
a_lengths_host[1] = static_cast<std::size_t>(K);
a_strides_host[0] = static_cast<std::size_t>(K);
a_strides_host[1] = static_cast<std::size_t>(1);
-b_lengths_host[0] = static_cast<std::size_t>(K);
-b_lengths_host[1] = static_cast<std::size_t>(N);
-b_strides_host[0] = static_cast<std::size_t>(N);
-b_strides_host[1] = static_cast<std::size_t>(1);
-c_lengths_host[0] = static_cast<std::size_t>(M);
-c_lengths_host[1] = static_cast<std::size_t>(N);
-c_strides_host[0] = static_cast<std::size_t>(N);
-c_strides_host[1] = static_cast<std::size_t>(1);
}
-else if(layout == GemmMatrixLayout::MK_NK_MN)
+else
{
-a_lengths_host[0] = static_cast<std::size_t>(M);
-a_lengths_host[1] = static_cast<std::size_t>(K);
-a_strides_host[0] = static_cast<std::size_t>(K);
+a_lengths_host[0] = static_cast<std::size_t>(K);
+a_lengths_host[1] = static_cast<std::size_t>(M);
+a_strides_host[0] = static_cast<std::size_t>(M);
a_strides_host[1] = static_cast<std::size_t>(1);
+}
+// B
+if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
+layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
+{
b_lengths_host[0] = static_cast<std::size_t>(N);
b_lengths_host[1] = static_cast<std::size_t>(K);
b_strides_host[0] = static_cast<std::size_t>(K);
b_strides_host[1] = static_cast<std::size_t>(1);
-c_lengths_host[0] = static_cast<std::size_t>(M);
-c_lengths_host[1] = static_cast<std::size_t>(N);
-c_strides_host[0] = static_cast<std::size_t>(N);
-c_strides_host[1] = static_cast<std::size_t>(1);
}
-else if(layout == GemmMatrixLayout::KM_KN_MN)
+else
{
-a_lengths_host[0] = static_cast<std::size_t>(K);
-a_lengths_host[1] = static_cast<std::size_t>(M);
-a_strides_host[0] = static_cast<std::size_t>(M);
-a_strides_host[1] = static_cast<std::size_t>(1);
b_lengths_host[0] = static_cast<std::size_t>(K);
b_lengths_host[1] = static_cast<std::size_t>(N);
b_strides_host[0] = static_cast<std::size_t>(N);
b_strides_host[1] = static_cast<std::size_t>(1);
-c_lengths_host[0] = static_cast<std::size_t>(M);
-c_lengths_host[1] = static_cast<std::size_t>(N);
-c_strides_host[0] = static_cast<std::size_t>(N);
-c_strides_host[1] = static_cast<std::size_t>(1);
}
-else if(layout == GemmMatrixLayout::KM_NK_MN)
-{
-a_lengths_host[0] = static_cast<std::size_t>(K);
-a_lengths_host[1] = static_cast<std::size_t>(M);
-a_strides_host[0] = static_cast<std::size_t>(M);
-a_strides_host[1] = static_cast<std::size_t>(1);
-b_lengths_host[0] = static_cast<std::size_t>(N);
-b_lengths_host[1] = static_cast<std::size_t>(K);
-b_strides_host[0] = static_cast<std::size_t>(K);
-b_strides_host[1] = static_cast<std::size_t>(1);
+// C
+if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
+layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
+{
c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N);
@@ -144,7 +131,10 @@ int main(int argc, char* argv[])
}
else
{
-std::runtime_error("wrong! not implemented");
+c_lengths_host[0] = static_cast<std::size_t>(N);
+c_lengths_host[1] = static_cast<std::size_t>(M);
+c_strides_host[0] = static_cast<std::size_t>(M);
+c_strides_host[1] = static_cast<std::size_t>(1);
}
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
@@ -185,38 +175,6 @@ int main(int argc, char* argv[])
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
}
-auto f_make_for_device_mk_kn_mn = [&]() {
-const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
-const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
-const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
-return make_tuple(a_desc, b_desc, c_desc);
-};
-auto f_make_for_device_mk_nk_mn = [&]() {
-const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
-const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
-const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
-return make_tuple(a_desc, b_desc, c_desc);
-};
-auto f_make_for_device_km_kn_mn = [&]() {
-const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
-const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
-const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
-return make_tuple(a_desc, b_desc, c_desc);
-};
-auto f_make_for_device_km_nk_mn = [&]() {
-const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
-const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
-const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
-return make_tuple(a_desc, b_desc, c_desc);
-};
#if USE_GEMM_XDL_MK_KN_MN
if(algo == GemmAlgo::Xdl_MK_KN_MN)
{
@@ -225,10 +183,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout");
}
-const auto descs = f_make_for_device_mk_kn_mn();
-device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(
-descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
+device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
@@ -240,10 +195,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout");
}
-const auto descs = f_make_for_device_mk_nk_mn();
-device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(
-descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
+device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
@@ -255,10 +207,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout");
}
-const auto descs = f_make_for_device_km_kn_mn();
-device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(
-descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
+device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
@@ -270,10 +219,55 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout");
}
-const auto descs = f_make_for_device_km_nk_mn();
+device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if(algo == GemmAlgo::Xdl_MK_KN_NM)
{
if(layout != GemmMatrixLayout::MK_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if(algo == GemmAlgo::Xdl_MK_NK_NM)
{
if(layout != GemmMatrixLayout::MK_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if(algo == GemmAlgo::Xdl_KM_KN_NM)
{
if(layout != GemmMatrixLayout::KM_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if(algo == GemmAlgo::Xdl_KM_NK_NM)
{
if(layout != GemmMatrixLayout::KM_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
-device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(
-descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
+device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
......
@@ -2,6 +2,9 @@
#define DEVICE_HPP
#include <memory>
#include <functional>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
@@ -74,7 +77,8 @@ float launch_and_time_kernel(
timer.End();
// std::this_thread::sleep_for (std::chrono::microseconds(10));
return timer.GetElapsedTime() / nrepeat;
}
#endif
@@ -7,6 +7,10 @@ enum GemmMatrixLayout
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
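// Naming is <A layout>_<B layout>_<C layout>: e.g. MK_NK_NM reads A as M x K,
// B as N x K (transposed B), and writes C transposed as N x M, matching the
// host_gemm reference loops.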
#endif
@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a,
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<double>(a(m, k)) * static_cast<double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<double>(a(m, k)) * static_cast<double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<double>(a(k, m)) * static_cast<double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<double>(a(k, m)) * static_cast<double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
......
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
}
};
struct GeneratorTensor_0
{
int value = 0;
template <typename... Is>
float operator()(Is...)
{
return value;
}
};
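// GeneratorTensor_0 fills a tensor with zeros; the wrw drivers use it to clear
// wei_device up front, presumably so the atomic split-K variants can accumulate
// partial results into it.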
struct GeneratorTensor_2
{
int min_value = 0;
......
WORKSPACE=$1
echo "workspace: " $WORKSPACE
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v $WORKSPACE:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
#--network host \
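# usage (assumed): pass the host-side source tree as the first argument, e.g.
#   ./docker_run.sh /path/to/composable_kernel
# it is mounted at /root/workspace inside the container.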
@@ -4,24 +4,12 @@
export ROCR_VISIBLE_DEVICES=0
export GPU_DEVICE_ORDINAL=0
## Boost
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
## Compiling
#export OLC_DEBUG_HIP_VERBOSE=1
#export OLC_DEBUG_HIP_DUMP=1
#export OLC_DEBUG_SAVE_TEMP_DIR=1
#rm -rf /root/_hip_binary_kernels_/
#rm -rf /tmp/olCompile*
-#make -j conv_fwd_driver_offline
+make -j conv_fwd_driver_offline
#make -j conv_bwd_driver_offline
#make -j conv_wrw_driver_offline
#make -j conv_fwd_driver_online
-make -j gemm_driver_offline
+#make -j gemm_driver_offline
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
LAYOUT=$1
ALGO=$2
VERIFY=$3
@@ -29,30 +17,121 @@ INIT=$4
LOG=$5
REPEAT=$6
################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
#M01=$7
#N01=$8
KBATCH=$7
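# KBATCH (7th positional arg) is the split-K factor; it is forwarded as the
# trailing argument of the per-shape driver invocations below.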
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
######### layout algo verify init log repeat M___ N___ K___
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
# Resnet50
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1
# 256x128x32 c64
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
# 128x128x32 c64
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
# 128x64x32 c64
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
# 64x128x32 c64
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
-################################################ layout algo verify init log repeat M___ N___ K___
-#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
-#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048
-./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096
-#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192
# 64x64x32 c32
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448