Commit 3eee1b9b authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

adding tall and skinny gemm

parent 67adf1b4
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemv.hpp" #include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemv_splitk.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -58,7 +58,7 @@ template < ...@@ -58,7 +58,7 @@ template <
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>, is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false> bool> = false>
struct deviceGemvDl : public DeviceGemv<ALayout, struct deviceTsmmDl : public DeviceTsmm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
...@@ -76,9 +76,9 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -76,9 +76,9 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
// GridwiseGemv // GridwiseTsmm
using GridwiseGemv = using GridwiseTsmm =
GridwiseGemvDl_km_kn_mn<BlockSize, GridwiseTsmmDl_km_kn_mn<BlockSize,
ADataType, ADataType,
AccDataType, AccDataType,
CDataType, CDataType,
...@@ -107,8 +107,8 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -107,8 +107,8 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseGemv::DefaultBlock2CTileMap; using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
using Argument = typename GridwiseGemv::Argument; using Argument = typename GridwiseTsmm::Argument;
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -116,14 +116,14 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -116,14 +116,14 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{ {
const index_t grid_size = GridwiseGemv::CalculateGridSize(karg.M, karg.N, karg.k_batch); const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto b2c_map = DefaultBlock2CTileMap{}; const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0; const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseGemv::CalculateHasMainKBlockLoop(K0); const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = const bool has_double_tail_k_block_loop =
GridwiseGemv::CalculateHasDoubleTailKBlockLoop(K0); GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0; float ave_time = 0;
...@@ -134,7 +134,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -134,7 +134,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{ {
if(karg.k_batch == 1) if(karg.k_batch == 1)
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -146,7 +146,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -146,7 +146,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
} }
else else
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -162,7 +162,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -162,7 +162,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
if(karg.k_batch == 1) if(karg.k_batch == 1)
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -174,7 +174,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -174,7 +174,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
} }
else else
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -189,7 +189,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -189,7 +189,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{ {
if(karg.k_batch == 1) if(karg.k_batch == 1)
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -201,7 +201,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -201,7 +201,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
} }
else else
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -216,7 +216,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -216,7 +216,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{ {
if(karg.k_batch == 1) if(karg.k_batch == 1)
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -228,7 +228,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -228,7 +228,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
} }
else else
{ {
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv, const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType, ADataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -264,7 +264,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -264,7 +264,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102") ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemv::CheckValidity(arg); return GridwiseTsmm::CheckValidity(arg);
} }
else else
{ {
...@@ -301,10 +301,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -301,10 +301,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
GridwiseGemv::CalculateMPadded(M), GridwiseTsmm::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N), GridwiseTsmm::CalculateNPadded(N),
K, GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseGemv::CalculateK0(K, KBatch), GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // // KBatch}; // //
} }
...@@ -325,6 +325,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -325,6 +325,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
CElementwiseOperation, CElementwiseOperation,
ck::index_t KBatch = 1) override // // ck::index_t KBatch = 1) override // //
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
...@@ -334,10 +335,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -334,10 +335,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
GridwiseGemv::CalculateMPadded(M), GridwiseTsmm::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N), GridwiseTsmm::CalculateNPadded(N),
K, GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseGemv::CalculateK0(K, KBatch), GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // // KBatch); // //
} }
...@@ -353,7 +354,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -353,7 +354,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "deviceGemvDl" str << "deviceTsmmDl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment