Commit a760a732 authored by rocking

A kernel of elementwise_2d (except global store)

parent cb1c4731
@@ -84,8 +84,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX;
constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
@@ -118,14 +118,14 @@ using DeviceReduceInstance =
struct Sub
{
-__host__ __device__ constexpr void operator()(F16& dst, const F16& src1, const F16& src2) const
+__host__ __device__ constexpr void operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
{
dst = src1 - src2;
}
};
-using DeviceElementwiseInstance =
-ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8>;
+using DeviceElementwiseInstance = ck::tensor_operation::device::
+DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
@@ -302,8 +302,8 @@ int main(int argc, char* argv[])
if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get()))
{
-throw std::runtime_error(
-"The runtime parameters seems not supported by the DeviceElementwise_2D instance, exiting!");
+throw std::runtime_error("The runtime parameters seems not supported by the "
+"DeviceElementwise_2D instance, exiting!");
};
auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer();
......
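Note on the example change above: the Sub functor now operates on CDataType instead of the hard-coded F16, and the DeviceElementwise_2D instance gains five trailing integer arguments (AThreadTransferSrcVectorDim, AThreadTransferSrcScalarPerVector, BThreadTransferSrcVectorDim, BThreadTransferSrcScalarPerVector, CThreadTransferSrcScalarPerVector), all set to 1 here. As a rough reference for what the instance computes, a minimal host-side sketch of the per-element contract; the helper name and the flat row-major vectors are illustrative, not part of the diff:

```cpp
#include <vector>

// Illustrative reference only: the device instance applies the functor
// independently at every (m, n) of three M x N views.
template <typename T, typename Functor>
void reference_elementwise_2d(std::vector<T>& c,
                              const std::vector<T>& a,
                              const std::vector<T>& b,
                              int M, int N, Functor functor)
{
    // Same calling convention as the Sub functor above: functor(dst, src1, src2).
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            functor(c[m * N + n], a[m * N + n], b[m * N + n]);
}
```

With the Sub functor from this diff, the loop reduces to c = a - b element by element.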
@@ -17,9 +17,17 @@ template <typename ADataType,
index_t MThreadPerBlock,
index_t NThreadPerBlock,
index_t MThreadTileSize,
-index_t NThreadTileSize>
+index_t NThreadTileSize,
+index_t AThreadTransferSrcVectorDim,
+index_t AThreadTransferSrcScalarPerVector,
+index_t BThreadTransferSrcVectorDim,
+index_t BThreadTransferSrcScalarPerVector,
+index_t CThreadTransferSrcScalarPerVector>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
{
+static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
+NThreadTileSize % BThreadTransferSrcScalarPerVector == 0);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -38,11 +46,16 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
BDataType,
CDataType,
GridDesc_M_N,
-GridDesc_M_N,
-GridDesc_M_N,
ElementwiseFunctor,
MThreadPerBlock,
NThreadPerBlock,
MThreadTileSize,
-NThreadTileSize>;
+NThreadTileSize,
+AThreadTransferSrcVectorDim,
+AThreadTransferSrcScalarPerVector,
+BThreadTransferSrcVectorDim,
+BThreadTransferSrcScalarPerVector,
+CThreadTransferSrcScalarPerVector>;
struct Argument : public BaseArgument
{
@@ -88,18 +101,12 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
float Run(const Argument& arg, int nrepeat = 1)
{
const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
ADataType,
BDataType,
CDataType,
GridDesc_M_N,
-GridDesc_M_N,
-GridDesc_M_N,
ElementwiseFunctor>;
-// TODO
-(void)arg;
-(void)nrepeat;
-(void)kernel;
+float avgTime = 0;
+const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
+if(nrepeat == 0)
......
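Note on the device-layer change above: the new template parameters describe how each thread's source loads may be vectorized, and the added static_assert requires NThreadTileSize to be a multiple of both source ScalarPerVector values. A small standalone sketch of that check, plugging in the values from the example instantiation (tile configuration 16, 16, 8, 8 with every vector width 1); illustrative only, not CK code:

```cpp
// Values taken from the example's DeviceElementwiseInstance.
constexpr int NThreadTileSize                   = 8;
constexpr int AThreadTransferSrcScalarPerVector = 1;
constexpr int BThreadTransferSrcScalarPerVector = 1;

// Each thread owns an MThreadTileSize x NThreadTileSize tile, so the assert
// demands that the per-vector scalar count evenly divide the N extent of that
// tile: width 1 always passes, and 2, 4 or 8 would also pass for a tile of 8.
static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
                  NThreadTileSize % BThreadTransferSrcScalarPerVector == 0,
              "ScalarPerVector must divide the per-thread N tile");
```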
@@ -10,16 +10,14 @@ template <typename GridwiseEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
-typename AGridDesc_M_N,
-typename BGridDesc_M_N,
-typename CGridDesc_M_N,
+typename GridDesc_M_N,
typename ElementwiseFunctor>
__global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
-const AGridDesc_M_N a_grid_desc_m_k,
-const BGridDesc_M_N b_grid_desc_m_k,
-const CGridDesc_M_N c_grid_desc_m_k,
+const GridDesc_M_N a_grid_desc_m_k,
+const GridDesc_M_N b_grid_desc_m_k,
+const GridDesc_M_N c_grid_desc_m_k,
const ElementwiseFunctor functor)
{
GridwiseEltwise::Run(p_a_global,
@@ -34,26 +32,58 @@ __global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
template <typename ADataType,
typename BDataType,
typename CDataType,
-typename AGridDesc_M_N,
-typename BGridDesc_M_N,
-typename CGridDesc_M_N,
+typename GridDesc_M_N,
typename ElementwiseFunctor,
index_t MThreadPerBlock,
index_t NThreadPerBlock,
index_t MThreadTileSize,
-index_t NThreadTileSize>
+index_t NThreadTileSize,
+index_t AThreadTransferSrcVectorDim,
+index_t AThreadTransferSrcScalarPerVector,
+index_t BThreadTransferSrcVectorDim,
+index_t BThreadTransferSrcScalarPerVector,
+index_t CThreadTransferSrcScalarPerVector>
struct GridwiseElementwise_2D
{
+static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
+make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));
+using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;
+static __device__ __host__ auto CalculateElementwiseIndex(const GridDesc_M_N& grid_desc_m_n)
+{
+const index_t thread_id = get_thread_local_1d_id();
+const index_t block_id = get_block_1d_id();
+const index_t M = grid_desc_m_n.GetLength(I0);
+const index_t gridSize_m = M / M_BlockTileSize;
+const index_t block_2d_idx_m = block_id % gridSize_m;
+const index_t block_2d_idx_n = block_id / gridSize_m;
+constexpr auto thread_desc =
+make_cluster_descriptor(Sequence<MThreadPerBlock, NThreadPerBlock>{}, Sequence<1, 0>{});
+const auto thread_2d_idx = thread_desc.CalculateBottomIndex(make_multi_index(thread_id));
+return make_multi_index(
+block_2d_idx_m * M_BlockTileSize + thread_2d_idx[I0] * MThreadTileSize,
+block_2d_idx_n * N_BlockTileSize + thread_2d_idx[I1] * NThreadTileSize);
+}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
-const AGridDesc_M_N a_grid_desc_m_n,
-const BGridDesc_M_N b_grid_desc_m_n,
-const CGridDesc_M_N c_grid_desc_m_n,
+const GridDesc_M_N a_grid_desc_m_n,
+const GridDesc_M_N b_grid_desc_m_n,
+const GridDesc_M_N c_grid_desc_m_n,
const ElementwiseFunctor functor)
{
// const index_t thread_id = get_thread_local_1d_id();
// const index_t block_id = get_block_1d_id();
// printf("block_id = %d, thread_id = %d \n", block_id, thread_id);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -68,14 +98,53 @@ struct GridwiseElementwise_2D
StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
c_thread_buf;
-// TODO - buffer_load, apply functor, buffer_store
-(void)a_global_buf;
-(void)b_global_buf;
+const auto a_global_load_offset = CalculateElementwiseIndex(a_grid_desc_m_n);
+const auto b_global_load_offset = CalculateElementwiseIndex(b_grid_desc_m_n);
+auto a_global_load = ThreadwiseTensorSliceTransfer_v2<
+ADataType,
+ADataType,
+GridDesc_M_N,
+decltype(thread_buf_desc_M_N),
+Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+Sequence<0, 1>, // DimAccessOrder
+AThreadTransferSrcVectorDim,
+AThreadTransferSrcScalarPerVector,
+1, // SrcScalarStrideInVector
+false>{a_grid_desc_m_n, a_global_load_offset};
+auto b_global_load = ThreadwiseTensorSliceTransfer_v2<
+BDataType,
+BDataType,
+GridDesc_M_N,
+decltype(thread_buf_desc_M_N),
+Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+Sequence<0, 1>, // DimAccessOrder
+BThreadTransferSrcVectorDim,
+BThreadTransferSrcScalarPerVector,
+1, // SrcScalarStrideInVector
+false>{b_grid_desc_m_n, b_global_load_offset};
+a_global_load.Run(
+a_grid_desc_m_n, a_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), a_thread_buf);
+b_global_load.Run(
+b_grid_desc_m_n, b_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), b_thread_buf);
+static_for<0, MThreadTileSize, 1>{}([&](auto m) {
+static_for<0, NThreadTileSize, 1>{}([&](auto n) {
+constexpr auto offset = thread_buf_desc_M_N.CalculateOffset(make_tuple(m, n));
+functor(c_thread_buf(Number<offset>{}),
+a_thread_buf(Number<offset>{}),
+b_thread_buf(Number<offset>{}));
+});
+});
+// TODO - global write
(void)c_global_buf;
-(void)a_thread_buf;
-(void)b_thread_buf;
-(void)c_thread_buf;
-(void)functor;
+// c_global_write.Run(
+// thread_buf_desc_M_N, c_thread_buf, c_grid_desc_m_n, make_tuple(I0, I0),
+// c_global_buf);
}
};
......
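Note on the gridwise change above: CalculateElementwiseIndex turns the flat block and thread ids into the upper-left (m, n) corner of the MThreadTileSize x NThreadTileSize tile that each thread loads and runs the functor over (the global store of c_thread_buf is still the remaining TODO). A standalone walk-through of that arithmetic with the example's sizes; the explicit thread_m/thread_n split stands in for make_cluster_descriptor(...).CalculateBottomIndex(thread_id), and the concrete M, block_id and thread_id values are made up for illustration:

```cpp
#include <cstdio>

int main()
{
    // Tile configuration from the example instance.
    constexpr int MThreadPerBlock = 16, NThreadPerBlock = 16;
    constexpr int MThreadTileSize = 8, NThreadTileSize = 8;
    constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize; // 128
    constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize; // 128

    const int M          = 1024;                // assumed tensor extent along M
    const int gridSize_m = M / M_BlockTileSize; // 1024 / 128 = 8 blocks along M

    const int block_id = 10, thread_id = 37; // made-up ids for the walk-through

    // Block coordinates, as in CalculateElementwiseIndex.
    const int block_2d_idx_m = block_id % gridSize_m; // 10 % 8 = 2
    const int block_2d_idx_n = block_id / gridSize_m; // 10 / 8 = 1

    // Illustrative stand-in for the cluster-descriptor lookup of the thread's
    // 2D position inside the 16 x 16 thread block.
    const int thread_m = thread_id / NThreadPerBlock; // 37 / 16 = 2
    const int thread_n = thread_id % NThreadPerBlock; // 37 % 16 = 5

    // Upper-left corner of this thread's 8 x 8 tile in the M x N problem.
    const int m0 = block_2d_idx_m * M_BlockTileSize + thread_m * MThreadTileSize; // 272
    const int n0 = block_2d_idx_n * N_BlockTileSize + thread_n * NThreadTileSize; // 168
    std::printf("thread tile origin = (%d, %d)\n", m0, n0);
    return 0;
}
```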