Commit 3eee1b9b authored by Harisankar Sadasivan
Browse files

adding tall and skinny gemm

parent 67adf1b4
......@@ -10,9 +10,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemv.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemv_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -58,7 +58,7 @@ template <
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct deviceGemvDl : public DeviceGemv<ALayout,
struct deviceTsmmDl : public DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
......@@ -76,9 +76,9 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// GridwiseGemv
using GridwiseGemv =
GridwiseGemvDl_km_kn_mn<BlockSize,
// GridwiseTsmm
using GridwiseTsmm =
GridwiseTsmmDl_km_kn_mn<BlockSize,
ADataType,
AccDataType,
CDataType,
......@@ -107,8 +107,8 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseGemv::DefaultBlock2CTileMap;
using Argument = typename GridwiseGemv::Argument;
using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
using Argument = typename GridwiseTsmm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -116,14 +116,14 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = GridwiseGemv::CalculateGridSize(karg.M, karg.N, karg.k_batch);
const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseGemv::CalculateHasMainKBlockLoop(K0);
const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemv::CalculateHasDoubleTailKBlockLoop(K0);
GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
......@@ -134,7 +134,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
......@@ -146,7 +146,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
......@@ -162,7 +162,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
......@@ -174,7 +174,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
......@@ -189,7 +189,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
......@@ -201,7 +201,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
......@@ -216,7 +216,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
......@@ -228,7 +228,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
......@@ -264,7 +264,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseGemv::CheckValidity(arg);
return GridwiseTsmm::CheckValidity(arg);
}
else
{
......@@ -301,10 +301,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideA,
StrideB,
StrideC,
GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N),
K,
GridwiseGemv::CalculateK0(K, KBatch),
GridwiseTsmm::CalculateMPadded(M),
GridwiseTsmm::CalculateNPadded(N),
GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
}
......@@ -325,6 +325,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
CElementwiseOperation,
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
......@@ -334,10 +335,10 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideA,
StrideB,
StrideC,
GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N),
K,
GridwiseGemv::CalculateK0(K, KBatch),
GridwiseTsmm::CalculateMPadded(M),
GridwiseTsmm::CalculateNPadded(N),
GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // //
}
......@@ -353,7 +354,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
auto str = std::stringstream();
// clang-format off
str << "deviceGemvDl"
str << "deviceTsmmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment