Commit 2faeaece authored by j4yan

navi_gemm_km_kn_mn_fp32 compiles and passes one test.

parent fd7eee0d
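
For orientation: this commit wires up a DLOPS (dot-product) GEMM path for the a[k, m] * b[k, n] = c[m, n] layout named in the instance file below. A minimal host-side reference of that layout convention, as a sketch (the function name and raw-pointer interface are illustrative, not part of the commit):

// Reference semantics of the km_kn_mn GEMM this device kernel computes.
// Illustrative sketch only: name and signature are hypothetical.
void reference_gemm_km_kn_mn(
    const float* a, const float* b, float* c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[k * M + m] * b[k * N + n]; // a is [K, M], b is [K, N]
            c[m * N + n] = acc;                     // c is [M, N], row-major
        }
}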
......@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
......
......@@ -33,7 +33,8 @@ template <
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
......@@ -56,17 +57,13 @@ template <
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
struct DeviceGemmDlops
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
KPerBlock,
K0PerBlock,
M1PerThread,
N1PerThread,
KPerThread,
......@@ -228,18 +225,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
using AK0M0M1K1GridDesc =
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{}));
using BK0N0N1K1GridDesc = decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{}));
using CM0M10M11N0N10N11GridDesc =
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{}));
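The renamed descriptor aliases encode two tilings. K is split as K = K0 * K1 so the innermost K1 elements can be moved with vector accesses, and M (resp. N) is split into per-block tiles. A small arithmetic sketch, assuming divisibility and that the M0/M1 split is by MPerBlock, as the block-descriptor names suggest (values match the instance below: K0PerBlock = 8, K1 = 2, MPerBlock = 128):

// Dimension split behind AGridDesc_K0_M_K1 -> AGridDesc_K0_M0_M1_K1 (sketch;
// the real descriptors are built from tensor transforms, not plain division).
constexpr int K = 256, K1 = 2, M = 1024, MPerBlock = 128;
constexpr int K0 = K / K1;    // a[K, M] viewed as a[K0, M, K1]
constexpr int M1 = MPerBlock; // per-block tile of M
constexpr int M0 = M / M1;    // number of M tiles: a[K0, M0, M1, K1]
static_assert(K0 * K1 == K && M0 * M1 == M, "divisibility assumed");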
// Argument
struct Argument : public BaseArgument
......@@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01}
......@@ -272,15 +266,19 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// b_element_op_{b_element_op},
// c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmDlops::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{
c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_);
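block_2_ctile_map_ replaces the old cblockid_to_m0_n0_block_cluster_adaptor_: it turns the flat block id into the (m0, n0) coordinate of the MPerBlock x NPerBlock output tile a workgroup owns. Conceptually (a sketch; the real map is a tensor adaptor, and the M01/N01 swizzle factors are still accepted by the constructor but no longer validated here):

// Conceptual effect of the block-to-C-tile map (sketch, not the adaptor code):
// const index_t N0 = N / NPerBlock; // tiles along N
// const index_t m0 = block_id / N0; // tile row this block computes
// const index_t n0 = block_id % N0; // tile column this block computes
// M01/N01 group neighbouring blocks into M01 x N01 clusters for L2 locality.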
......@@ -292,11 +290,15 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc;
BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc;
CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -309,36 +311,35 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdl::Argument;
using Argument = DeviceGemmDlops::Argument;
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
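K0 is now read from the 4-d descriptor (the M split leaves its outermost length unchanged), and these two predicates select one of the four kernel instantiations below, which peel the double-buffered K loop differently. A plausible shape of the predicates, hedged (the actual formulas live in the gridwise header and may differ):

// Sketch of the loop-peeling predicates (assumed, not copied from the source):
// bool CalculateHasMainKBlockLoop(index_t K0)
// { return K0 / K0PerBlock > 2; }        // work remains beyond the tail
// bool CalculateHasDoubleTailKBlockLoop(index_t K0)
// { return (K0 / K0PerBlock) % 2 == 0; } // even count -> two-iteration tail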
......@@ -351,10 +352,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
......@@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_);
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
......@@ -377,10 +378,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
......@@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_);
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
......@@ -403,10 +404,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
......@@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_);
arg.block_2_ctile_map_);
}
else
{
......@@ -429,10 +430,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
......@@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_);
arg.block_2_ctile_map_);
}
return ave_time;
......@@ -468,11 +469,8 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
// polymorphic
......@@ -555,17 +553,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdl"
str << "DeviceGemmDlops"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
......
......@@ -7,8 +7,9 @@
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace ck {
......@@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc),
remove_reference_t<decltype(a_k0_m0_m1_k1_grid_desc)>,
decltype(a_k0_m0_m1_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
......@@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc),
remove_reference_t<decltype(b_k0_n0_n1_k1_grid_desc)>,
decltype(b_k0_n0_n1_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
......@@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
__syncthreads();
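For context on this tail: the main loop ping-pongs between two LDS buffers so the global-to-LDS copy of iteration i+1 overlaps the math of iteration i; when exactly two iterations remain, the source windows are moved one last time (the reflowed lines above) and both buffers are drained without further copies. In outline (illustrative pseudo-C++, not the kernel's code):

// for(; iters_left > 2; iters_left -= 2) {
//     copy(global -> lds[1]); __syncthreads(); gemm(lds[0]); // even step
//     copy(global -> lds[0]); __syncthreads(); gemm(lds[1]); // odd step
// }
// // HasDoubleTailKBlockLoop: drain both buffers
// gemm(lds[0]); __syncthreads(); gemm(lds[1]);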
......@@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
......@@ -569,7 +569,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
......
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
......@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
};
} // namespace ck
#endif
......@@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_instance)
add_library(device_gemm_dlops_instance SHARED device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp)
target_compile_features(device_gemm_dlops_instance PUBLIC)
set_target_properties(device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_dlops_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_dlops_instance)
......@@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order|
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
>;
// DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
DeviceGemmDlops<F32,
F32,
F32,
F32,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
GemmDefault,
256,
128,
128,
8,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<4, 1, 1, 2>,
S<2, 1, 128, 1>,
S<1, 2, 0, 3>,
S<1, 2, 0, 3>,
S<4, 1, 1, 2>,
S<1, 2, 0, 3>,
S<1, 1, 1, 2>,
S<4, 1, 1, 2>,
S<2, 1, 128, 1>,
S<1, 2, 0, 3>,
S<1, 2, 0, 3>,
S<4, 1, 1, 2>,
S<1, 2, 0, 3>,
S<1, 1, 1, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>>;
void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
......
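Typical client-side use of this registration hook, sketched under the assumption that it follows the same factory pattern as the existing XDL instance libraries (helper names as used by those libraries; none of this appears in the diff):

// std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>> ptrs;
// add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(ptrs);
// for(auto& gemm : ptrs)
// {
//     auto arg = gemm->MakeArgumentPointer(...); // device pointers + M, N, K, strides
//     if(gemm->IsSupportedArgument(arg.get()))
//         gemm->GetInvokerPointer()->Run(arg.get());
// }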
......@@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_dlops)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
......
add_test_executable(test_gemm_dlops_fp32 gemm_fp32.cpp)
add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance)
......
......@@ -6,7 +6,7 @@
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
......@@ -15,7 +15,6 @@
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_dlops_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
......
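The relocated ../gemm/gemm_util.hpp include suggests the new test reuses the existing GEMM harness. Its expected flow, as an outline (inferred from the includes above; the test body itself is elided from this diff):

// 1. allocate host tensors a(K, M), b(K, N), c(M, N); randomize a and b;
// 2. compute a reference c on the host (host_gemm.hpp / reference_gemm.hpp);
// 3. run the DeviceGemmDlops instance on device buffers via the harness;
// 4. compare the device output against the reference within fp32 tolerance.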