Commit dba65b1c authored by rocking

Rewrite the gridwise_elementwise_2d as 1d version
parent 6a781e51
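
The commit collapses the nine tile/transfer template parameters of DeviceElementwise_2D down to three (ThreadPerBlock, ThreadTileSize, ScalarPerVector) and flattens the 2D [m, n] problem into a 1D [m * n] one via a merge transform. As a rough orientation before the diff, here is a minimal standalone sketch (plain C++, illustrative names only, not code from the commit) of the index mapping that the merged descriptor performs:

// Illustrative only: shows how a flat index k over the merged [m * n] space
// maps back to the memory offset of a naive (m, n) descriptor with strides
// (stride0, stride1). The real mapping is built by make_merge_transform /
// transform_tensor_descriptor in MakeDescriptor_M0 below.
int flat_to_offset(int k, int n, int stride0, int stride1)
{
    const int i = k / n; // row in the original 2d view
    const int j = k % n; // column in the original 2d view
    return i * stride0 + j * stride1;
}
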
@@ -165,10 +165,10 @@ struct Div
};
using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 256, 32, 8>;
using DeviceElementwiseDivInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......
@@ -4,7 +4,7 @@
#include "device.hpp"
#include "device_elementwise.hpp"
#include "gridwise_elementwise_2d.hpp"
#include "gridwise_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
@@ -14,48 +14,40 @@ template <typename ADataType,
typename BDataType,
typename CDataType,
typename ElementwiseFunctor,
index_t MThreadPerBlock,
index_t NThreadPerBlock,
index_t MThreadTileSize,
index_t NThreadTileSize,
index_t AThreadTransferSrcVectorDim,
index_t AThreadTransferSrcScalarPerVector,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
index_t CThreadTransferSrcScalarPerVector>
index_t ThreadPerBlock,
index_t ThreadTileSize,
index_t ScalarPerVector>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
{
static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
NThreadTileSize % BThreadTransferSrcScalarPerVector == 0);
static_assert(ThreadTileSize % ScalarPerVector == 0);
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
static constexpr auto I0 = Number<0>{};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static auto Make2dDescriptor_M_N(const std::vector<int>& shape, const std::vector<int>& stride)
static auto MakeDescriptor_M0(const std::vector<int>& shape, const std::vector<int>& stride)
{
return make_naive_tensor_descriptor(make_tuple(shape[0], shape[1]),
make_tuple(stride[0], stride[1]));
const int m = shape[0];
const int n = shape[1];
// 2d desc - [m, n]
const auto desc_m_n =
make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));
// 1d desc - [m * n]
return transform_tensor_descriptor(desc_m_n,
make_tuple(make_merge_transform(make_tuple(m, n))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
}
static constexpr index_t BlockSize = MThreadPerBlock * NThreadPerBlock;
static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;
using GridDesc_M_N = decltype(Make2dDescriptor_M_N({1, 1}, {1, 1}));
using GridwiseEltwise = GridwiseElementwise_2D<ADataType,
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
BDataType,
CDataType,
GridDesc_M_N,
GridDesc_M0,
ElementwiseFunctor,
MThreadPerBlock,
NThreadPerBlock,
MThreadTileSize,
NThreadTileSize,
AThreadTransferSrcVectorDim,
AThreadTransferSrcScalarPerVector,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
CThreadTransferSrcScalarPerVector>;
ThreadPerBlock,
ThreadTileSize,
ScalarPerVector>;
struct Argument : public BaseArgument
{
@@ -70,9 +62,9 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
a_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_a)),
b_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_b)),
c_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_c)),
a_grid_desc_m0_(MakeDescriptor_M0(shape, stride_a)),
b_grid_desc_m0_(MakeDescriptor_M0(shape, stride_b)),
c_grid_desc_m0_(MakeDescriptor_M0(shape, stride_c)),
functor_(functor)
{
}
@@ -80,47 +72,42 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
const ADataType* p_a_;
const BDataType* p_b_;
CDataType* p_c_;
GridDesc_M_N a_grid_desc_m_n_;
GridDesc_M_N b_grid_desc_m_n_;
GridDesc_M_N c_grid_desc_m_n_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_;
};
struct Invoker : public BaseInvoker
{
index_t CalculateGridSize(const GridDesc_M_N& grid_desc_m_n)
index_t CalculateGridSize(const GridDesc_M0& grid_desc_m0)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto N = grid_desc_m_n.GetLength(I1);
assert(M % M_BlockTileSize == 0);
assert(N % N_BlockTileSize == 0);
return (M / M_BlockTileSize) * (N / N_BlockTileSize);
const auto gridTileSize = grid_desc_m0.GetLength(I0);
return gridTileSize / BlockTileSize;
}
float Run(const Argument& arg, int nrepeat = 1)
{
const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
const auto kernel = kernel_elementwise_1d<GridwiseEltwise,
ADataType,
BDataType,
CDataType,
GridDesc_M_N,
GridDesc_M0,
ElementwiseFunctor>;
float avgTime = 0;
const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m0_);
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(gridSize),
dim3(BlockSize),
dim3(ThreadPerBlock),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m_n_,
arg.b_grid_desc_m_n_,
arg.c_grid_desc_m_n_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.functor_);
}
else
@@ -128,14 +115,14 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
avgTime = launch_and_time_kernel(kernel,
nrepeat,
dim3(gridSize),
dim3(BlockSize),
dim3(ThreadPerBlock),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m_n_,
arg.b_grid_desc_m_n_,
arg.c_grid_desc_m_n_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.functor_);
}
return avgTime;
@@ -154,10 +141,10 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
if(pArg == nullptr)
return false;
const auto M = pArg->c_grid_desc_m_n_.GetLength(I0);
const auto N = pArg->c_grid_desc_m_n_.GetLength(I1);
// m * n
const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);
if(M % M_BlockTileSize != 0 && N % N_BlockTileSize != 0)
if(m0 % BlockTileSize != 0)
return false;
return true;
......
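
With the flattened layout, grid sizing in the Invoker reduces to dividing the merged length by BlockTileSize = ThreadPerBlock * ThreadTileSize. A small worked example, assuming the 256/32/8 configuration from the instances above and a hypothetical 1024 x 1024 tensor (shape values are not taken from the commit):

// Hypothetical numbers for illustration, matching the <..., 256, 32, 8> instances above.
#include <cassert>

int main()
{
    constexpr int ThreadPerBlock  = 256;
    constexpr int ThreadTileSize  = 32;
    constexpr int ScalarPerVector = 8;
    constexpr int BlockTileSize   = ThreadPerBlock * ThreadTileSize; // 8192 elements per block

    const int m = 1024, n = 1024;         // example tensor shape (illustrative)
    assert((m * n) % BlockTileSize == 0); // mirrors the IsSupportedArgument check in the diff above
    const int gridSize = (m * n) / BlockTileSize; // 128 blocks
    return gridSize == 128 ? 0 : 1;
}
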
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M0,
typename ElementwiseFunctor>
__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m0,
b_grid_desc_m0,
c_grid_desc_m0,
functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ThreadPerBlock,
index_t ThreadTileSize,
index_t ScalarPerVector>
struct GridwiseElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
static constexpr auto thread_desc_M0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ __host__ auto CalculateElementwiseIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t block_id = get_block_1d_id();
return make_multi_index(block_id * BlockTileSize + thread_id * ScalarPerVector);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, ScalarPerVector, true> c_thread_buf;
const auto thread_to_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
GridDesc_M0,
decltype(thread_desc_M0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_to_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
GridDesc_M0,
decltype(thread_desc_M0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{b_grid_desc_m0, thread_to_global_offset};
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<CDataType,
CDataType,
decltype(thread_desc_M0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
c_grid_desc_m0, thread_to_global_offset, PassThrough{}};
int num_iter = ThreadTileSize / ScalarPerVector;
constexpr auto thread_to_global_step = make_multi_index(ThreadPerBlock * ScalarPerVector);
do
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_M0, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_M0, make_tuple(I0), b_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_M0.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
c_global_write.Run(thread_desc_M0,
make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf,
c_grid_desc_m0,
c_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, thread_to_global_step);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, thread_to_global_step);
c_global_write.MoveDstSliceWindow(c_grid_desc_m0, thread_to_global_step);
} while(--num_iter);
}
};
} // namespace ck
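
For reference, each thread of the 1D kernel starts at block_id * BlockTileSize + thread_id * ScalarPerVector and advances by ThreadPerBlock * ScalarPerVector for ThreadTileSize / ScalarPerVector iterations. The following standalone sketch (illustrative only, assuming the 256/32/8 configuration; it is not part of the kernel) traces the offsets a single thread would touch:

// Illustrative trace of the global offsets one thread touches, mirroring
// CalculateElementwiseIndex() and the thread_to_global_step used by
// MoveSrcSliceWindow / MoveDstSliceWindow above.
#include <cstdio>

void print_thread_offsets(int block_id, int thread_id)
{
    constexpr int ThreadPerBlock  = 256;
    constexpr int ThreadTileSize  = 32;
    constexpr int ScalarPerVector = 8;
    constexpr int BlockTileSize   = ThreadPerBlock * ThreadTileSize;

    int offset = block_id * BlockTileSize + thread_id * ScalarPerVector;
    for(int iter = 0; iter < ThreadTileSize / ScalarPerVector; ++iter)
    {
        // each iteration reads/writes ScalarPerVector contiguous elements at `offset`
        std::printf("iter %d: elements [%d, %d)\n", iter, offset, offset + ScalarPerVector);
        offset += ThreadPerBlock * ScalarPerVector; // one thread_to_global_step
    }
}
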
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M_N,
typename ElementwiseFunctor>
__global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M_N a_grid_desc_m_k,
const GridDesc_M_N b_grid_desc_m_k,
const GridDesc_M_N c_grid_desc_m_k,
const ElementwiseFunctor functor)
{
GridwiseEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m_k,
b_grid_desc_m_k,
c_grid_desc_m_k,
functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M_N,
typename ElementwiseFunctor,
index_t MThreadPerBlock,
index_t NThreadPerBlock,
index_t MThreadTileSize,
index_t NThreadTileSize,
index_t AThreadTransferSrcVectorDim,
index_t AThreadTransferSrcScalarPerVector,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
index_t CThreadTransferSrcScalarPerVector>
struct GridwiseElementwise_2D
{
static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;
static __device__ __host__ auto CalculateElementwiseIndex(const GridDesc_M_N& grid_desc_m_n)
{
const index_t thread_id = get_thread_local_1d_id();
const index_t block_id = get_block_1d_id();
const index_t M = grid_desc_m_n.GetLength(I0);
const index_t gridSize_m = M / M_BlockTileSize;
const index_t block_2d_idx_m = block_id % gridSize_m;
const index_t block_2d_idx_n = block_id / gridSize_m;
constexpr auto thread_desc =
make_cluster_descriptor(Sequence<MThreadPerBlock, NThreadPerBlock>{}, Sequence<1, 0>{});
const auto thread_2d_idx = thread_desc.CalculateBottomIndex(make_multi_index(thread_id));
return make_multi_index(
block_2d_idx_m * M_BlockTileSize + thread_2d_idx[I0] * MThreadTileSize,
block_2d_idx_n * N_BlockTileSize + thread_2d_idx[I1] * NThreadTileSize);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M_N a_grid_desc_m_n,
const GridDesc_M_N b_grid_desc_m_n,
const GridDesc_M_N c_grid_desc_m_n,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m_n.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m_n.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, MThreadTileSize * NThreadTileSize, true>
a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, MThreadTileSize * NThreadTileSize, true>
b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
c_thread_buf;
const auto a_global_load_offset = CalculateElementwiseIndex(a_grid_desc_m_n);
const auto b_global_load_offset = CalculateElementwiseIndex(b_grid_desc_m_n);
auto a_global_load = ThreadwiseTensorSliceTransfer_v2<
ADataType,
ADataType,
GridDesc_M_N,
decltype(thread_buf_desc_M_N),
Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
AThreadTransferSrcVectorDim,
AThreadTransferSrcScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m_n, a_global_load_offset};
auto b_global_load = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
GridDesc_M_N,
decltype(thread_buf_desc_M_N),
Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
1, // SrcScalarStrideInVector
false>{b_grid_desc_m_n, b_global_load_offset};
a_global_load.Run(
a_grid_desc_m_n, a_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m_n, b_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), b_thread_buf);
static_for<0, MThreadTileSize, 1>{}([&](auto m) {
static_for<0, NThreadTileSize, 1>{}([&](auto n) {
constexpr auto offset = thread_buf_desc_M_N.CalculateOffset(make_tuple(m, n));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
});
// TODO - global write
const auto c_global_write_offset = CalculateElementwiseIndex(c_grid_desc_m_n);
auto c_global_write = ThreadwiseTensorSliceTransfer_v1r3<
CDataType,
CDataType,
decltype(thread_buf_desc_M_N),
GridDesc_M_N,
PassThrough,
Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // DstVectorDim
CThreadTransferSrcScalarPerVector, // DstScalarPerVector
InMemoryDataOperationEnum::Set, // DstInMemOp
1, // DstScalarStrideInVector
false>{c_grid_desc_m_n, c_global_write_offset, PassThrough{}};
c_global_write.Run(
thread_buf_desc_M_N, make_tuple(I0, I0), c_thread_buf, c_grid_desc_m_n, c_global_buf);
}
};
} // namespace ck