"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "8acb270c90d0f400e9538db2e665dc61e8907a2f"
Commit e3a09b57 authored by rocking's avatar rocking
Browse files

Add gridwise_elementwise_2d api

parent 6818b58c
......@@ -125,7 +125,7 @@ struct Sub
};
using DeviceElementwiseInstance =
ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, 256, Sub>;
ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......
......@@ -4,6 +4,7 @@
#include "device.hpp"
#include "device_elementwise.hpp"
#include "gridwise_elementwise_2d.hpp"
namespace ck {
namespace tensor_operation {
......@@ -12,17 +13,36 @@ namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
index_t BlockSize,
typename ElementwiseFunctor>
typename ElementwiseFunctor,
index_t MThreadPerBlock,
index_t NThreadPerBlock,
index_t MThreadTileSize,
index_t NThreadTileSize>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static auto Make2dDescriptor_M_N(const std::vector<int>& shape, const std::vector<int>& stride)
{
return make_naive_tensor_descriptor(make_tuple(shape[0], shape[1]),
make_tuple(stride[0], stride[1]));
}
using GridDesc_M_N = decltype(Make2dDescriptor_M_N({1, 1}, {1, 1}));
static constexpr index_t BlockSize = MThreadPerBlock * NThreadPerBlock;
static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;
using GridDesc_M_N = decltype(Make2dDescriptor_M_N({1, 1}, {1, 1}));
using GridwiseEltwise = GridwiseElementwise_2D<ADataType,
BDataType,
CDataType,
GridDesc_M_N,
GridDesc_M_N,
GridDesc_M_N,
ElementwiseFunctor,
MThreadTileSize,
NThreadTileSize>;
struct Argument : public BaseArgument
{
......@@ -55,12 +75,63 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
struct Invoker : public BaseInvoker
{
index_t CalculateGridSize(const GridDesc_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto N = grid_desc_m_n.GetLength(I1);
assert(M % M_BlockTileSize == 0);
assert(N % N_BlockTileSize == 0);
return (M / M_BlockTileSize) * (N / N_BlockTileSize);
}
float Run(const Argument& arg, int nrepeat = 1)
{
const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
ADataType,
BDataType,
CDataType,
GridDesc_M_N,
GridDesc_M_N,
GridDesc_M_N,
ElementwiseFunctor>;
// TODO
(void)arg;
(void)nrepeat;
return 0;
(void)kernel;
float avgTime = 0;
const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(gridSize),
dim3(BlockSize),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m_n_,
arg.b_grid_desc_m_n_,
arg.c_grid_desc_m_n_,
arg.functor_);
}
else
{
avgTime = launch_and_time_kernel(kernel,
nrepeat,
dim3(gridSize),
dim3(BlockSize),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m_n_,
arg.b_grid_desc_m_n_,
arg.c_grid_desc_m_n_,
arg.functor_);
}
return avgTime;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
......@@ -71,9 +142,18 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
// TODO: properly implement this check
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return pArg != nullptr;
if(pArg == nullptr)
return false;
const auto M = pArg->c_grid_desc_m_n_.GetLength(I0);
const auto N = pArg->c_grid_desc_m_n_.GetLength(I1);
if(M % M_BlockTileSize != 0 && N % N_BlockTileSize != 0)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -107,7 +187,6 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
// clang-format off
str << "DeviceElementwise_2D"
<< "<"
<< BlockSize
<< ">";
// clang-format on
......
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc_M_N,
typename BGridDesc_M_N,
typename CGridDesc_M_N,
typename ElementwiseFunctor>
__global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const AGridDesc_M_N a_grid_desc_m_k,
const BGridDesc_M_N b_grid_desc_m_k,
const CGridDesc_M_N c_grid_desc_m_k,
const ElementwiseFunctor functor)
{
GridwiseEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m_k,
b_grid_desc_m_k,
c_grid_desc_m_k,
functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc_M_N,
typename BGridDesc_M_N,
typename CGridDesc_M_N,
typename ElementwiseFunctor,
index_t MThreadTileSize,
index_t NThreadTileSize>
struct GridwiseElementwise_2D
{
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const AGridDesc_M_N a_grid_desc_m_n,
const BGridDesc_M_N b_grid_desc_m_n,
const CGridDesc_M_N c_grid_desc_m_n,
const ElementwiseFunctor functor)
{
// const index_t thread_id = get_thread_local_1d_id();
// const index_t block_id = get_block_1d_id();
// printf("block_id = %d, thread_id = %d \n", block_id, thread_id);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m_n.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m_n.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, MThreadTileSize * NThreadTileSize, true>
a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, MThreadTileSize * NThreadTileSize, true>
b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
c_thread_buf;
// TODO - buffer_load, apply functor, buffer_store
(void)a_global_buf;
(void)b_global_buf;
(void)c_global_buf;
(void)a_thread_buf;
(void)b_thread_buf;
(void)c_thread_buf;
(void)functor;
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment