Commit a760a732 authored by rocking

A kernel of elementwise_2d (except global store)

parent cb1c4731
@@ -84,8 +84,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
              8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on

 constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
 constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX;
 constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
 constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
@@ -118,14 +118,14 @@ using DeviceReduceInstance =
 struct Sub
 {
-    __host__ __device__ constexpr void operator()(F16& dst, const F16& src1, const F16& src2) const
+    __host__ __device__ constexpr void operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
     {
         dst = src1 - src2;
     }
 };

-using DeviceElementwiseInstance =
-    ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8>;
+using DeviceElementwiseInstance = ck::tensor_operation::device::
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8, 1, 1, 1, 1, 1>;

 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
@@ -302,8 +302,8 @@ int main(int argc, char* argv[])
     if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get()))
     {
-        throw std::runtime_error(
-            "The runtime parameters seems not supported by the DeviceElementwise_2D instance, exiting!");
+        throw std::runtime_error("The runtime parameters seems not supported by the "
+                                 "DeviceElementwise_2D instance, exiting!");
     };

     auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer();
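For reference, a sketch of how the trailing template arguments of the DeviceElementwiseInstance above line up with the parameter names introduced in this commit. The annotated alias below reuses CDataType and Sub from the example file; it assumes the device template's leading (collapsed) parameters are ADataType, BDataType, CDataType and ElementwiseFunctor, mirroring the gridwise template further down.

    // Annotated form of the instance used above; the comments name the template
    // parameters. The five trailing 1s are the per-thread transfer parameters
    // added in this commit.
    using DeviceElementwiseInstance = ck::tensor_operation::device::DeviceElementwise_2D<
        CDataType, // ADataType
        CDataType, // BDataType
        CDataType, // CDataType
        Sub,       // ElementwiseFunctor
        16,        // MThreadPerBlock
        16,        // NThreadPerBlock
        8,         // MThreadTileSize
        8,         // NThreadTileSize
        1,         // AThreadTransferSrcVectorDim
        1,         // AThreadTransferSrcScalarPerVector
        1,         // BThreadTransferSrcVectorDim
        1,         // BThreadTransferSrcScalarPerVector
        1>;        // CThreadTransferSrcScalarPerVector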
@@ -17,9 +17,17 @@ template <typename ADataType,
           index_t MThreadPerBlock,
           index_t NThreadPerBlock,
           index_t MThreadTileSize,
-          index_t NThreadTileSize>
+          index_t NThreadTileSize,
+          index_t AThreadTransferSrcVectorDim,
+          index_t AThreadTransferSrcScalarPerVector,
+          index_t BThreadTransferSrcVectorDim,
+          index_t BThreadTransferSrcScalarPerVector,
+          index_t CThreadTransferSrcScalarPerVector>
 struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
 {
+    static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
+                  NThreadTileSize % BThreadTransferSrcScalarPerVector == 0);
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -38,11 +46,16 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
                   BDataType,
                   CDataType,
                   GridDesc_M_N,
-                  GridDesc_M_N,
-                  GridDesc_M_N,
                   ElementwiseFunctor,
+                  MThreadPerBlock,
+                  NThreadPerBlock,
                   MThreadTileSize,
-                  NThreadTileSize>;
+                  NThreadTileSize,
+                  AThreadTransferSrcVectorDim,
+                  AThreadTransferSrcScalarPerVector,
+                  BThreadTransferSrcVectorDim,
+                  BThreadTransferSrcScalarPerVector,
+                  CThreadTransferSrcScalarPerVector>;

     struct Argument : public BaseArgument
     {
@@ -88,18 +101,12 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
         float Run(const Argument& arg, int nrepeat = 1)
         {
             const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       GridDesc_M_N,
-                                                      GridDesc_M_N,
-                                                      GridDesc_M_N,
                                                       ElementwiseFunctor>;
-            // TODO
-            (void)arg;
-            (void)nrepeat;
-            (void)kernel;

             float avgTime = 0;
             const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
             if(nrepeat == 0)
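Run now launches kernel_elementwise_2d over gridSize workgroups, where gridSize comes from CalculateGridSize, which is not shown in this diff. A minimal host-side sketch of the assumed relation, one workgroup per M_BlockTileSize x N_BlockTileSize tile, which is consistent with the block-index arithmetic in CalculateElementwiseIndex below:

    // Host-side sketch only; this version of CalculateGridSize is an assumption,
    // not code from the repository. It assumes one workgroup per
    // (M_BlockTileSize x N_BlockTileSize) tile and a problem size divisible by it.
    #include <cstdio>

    int main()
    {
        constexpr int MThreadPerBlock = 16, NThreadPerBlock = 16; // values from the example
        constexpr int MThreadTileSize = 8, NThreadTileSize = 8;
        constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize; // 128
        constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize; // 128

        const int M = 1024, N = 512; // hypothetical problem size
        const int gridSize_m = M / M_BlockTileSize;
        const int gridSize_n = N / N_BlockTileSize;
        const int gridSize   = gridSize_m * gridSize_n; // assumed result of CalculateGridSize

        std::printf("gridSize = %d (%d x %d workgroups)\n", gridSize, gridSize_m, gridSize_n);
        return 0;
    }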
@@ -10,16 +10,14 @@ template <typename GridwiseEltwise,
           typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename AGridDesc_M_N,
-          typename BGridDesc_M_N,
-          typename CGridDesc_M_N,
+          typename GridDesc_M_N,
           typename ElementwiseFunctor>
 __global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
                                       const BDataType* __restrict__ p_b_global,
                                       CDataType* __restrict__ p_c_global,
-                                      const AGridDesc_M_N a_grid_desc_m_k,
-                                      const BGridDesc_M_N b_grid_desc_m_k,
-                                      const CGridDesc_M_N c_grid_desc_m_k,
+                                      const GridDesc_M_N a_grid_desc_m_k,
+                                      const GridDesc_M_N b_grid_desc_m_k,
+                                      const GridDesc_M_N c_grid_desc_m_k,
                                       const ElementwiseFunctor functor)
 {
     GridwiseEltwise::Run(p_a_global,
@@ -34,26 +32,58 @@ __global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename AGridDesc_M_N,
-          typename BGridDesc_M_N,
-          typename CGridDesc_M_N,
+          typename GridDesc_M_N,
           typename ElementwiseFunctor,
+          index_t MThreadPerBlock,
+          index_t NThreadPerBlock,
           index_t MThreadTileSize,
-          index_t NThreadTileSize>
+          index_t NThreadTileSize,
+          index_t AThreadTransferSrcVectorDim,
+          index_t AThreadTransferSrcScalarPerVector,
+          index_t BThreadTransferSrcVectorDim,
+          index_t BThreadTransferSrcScalarPerVector,
+          index_t CThreadTransferSrcScalarPerVector>
 struct GridwiseElementwise_2D
 {
+    static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));
+    using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
+    static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;
+
+    static __device__ __host__ auto CalculateElementwiseIndex(const GridDesc_M_N& grid_desc_m_n)
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t block_id  = get_block_1d_id();
+
+        const index_t M          = grid_desc_m_n.GetLength(I0);
+        const index_t gridSize_m = M / M_BlockTileSize;
+
+        const index_t block_2d_idx_m = block_id % gridSize_m;
+        const index_t block_2d_idx_n = block_id / gridSize_m;
+
+        constexpr auto thread_desc =
+            make_cluster_descriptor(Sequence<MThreadPerBlock, NThreadPerBlock>{}, Sequence<1, 0>{});
+        const auto thread_2d_idx = thread_desc.CalculateBottomIndex(make_multi_index(thread_id));
+
+        return make_multi_index(
+            block_2d_idx_m * M_BlockTileSize + thread_2d_idx[I0] * MThreadTileSize,
+            block_2d_idx_n * N_BlockTileSize + thread_2d_idx[I1] * NThreadTileSize);
+    }
+
     __device__ static void Run(const ADataType* __restrict__ p_a_global,
                                const BDataType* __restrict__ p_b_global,
                                CDataType* __restrict__ p_c_global,
-                               const AGridDesc_M_N a_grid_desc_m_n,
-                               const BGridDesc_M_N b_grid_desc_m_n,
-                               const CGridDesc_M_N c_grid_desc_m_n,
+                               const GridDesc_M_N a_grid_desc_m_n,
+                               const GridDesc_M_N b_grid_desc_m_n,
+                               const GridDesc_M_N c_grid_desc_m_n,
                                const ElementwiseFunctor functor)
     {
-        // const index_t thread_id = get_thread_local_1d_id();
-        // const index_t block_id = get_block_1d_id();
-        // printf("block_id = %d, thread_id = %d \n", block_id, thread_id);
-
         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -68,14 +98,53 @@ struct GridwiseElementwise_2D
         StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
             c_thread_buf;

-        // TODO - buffer_load, apply functor, buffer_store
-        (void)a_global_buf;
-        (void)b_global_buf;
+        const auto a_global_load_offset = CalculateElementwiseIndex(a_grid_desc_m_n);
+        const auto b_global_load_offset = CalculateElementwiseIndex(b_grid_desc_m_n);
+
+        auto a_global_load = ThreadwiseTensorSliceTransfer_v2<
+            ADataType,
+            ADataType,
+            GridDesc_M_N,
+            decltype(thread_buf_desc_M_N),
+            Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+            Sequence<0, 1>,                             // DimAccessOrder
+            AThreadTransferSrcVectorDim,
+            AThreadTransferSrcScalarPerVector,
+            1,      // SrcScalarStrideInVector
+            false>{a_grid_desc_m_n, a_global_load_offset};
+
+        auto b_global_load = ThreadwiseTensorSliceTransfer_v2<
+            BDataType,
+            BDataType,
+            GridDesc_M_N,
+            decltype(thread_buf_desc_M_N),
+            Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+            Sequence<0, 1>,                             // DimAccessOrder
+            BThreadTransferSrcVectorDim,
+            BThreadTransferSrcScalarPerVector,
+            1,      // SrcScalarStrideInVector
+            false>{b_grid_desc_m_n, b_global_load_offset};
+
+        a_global_load.Run(
+            a_grid_desc_m_n, a_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), a_thread_buf);
+        b_global_load.Run(
+            b_grid_desc_m_n, b_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), b_thread_buf);
+
+        static_for<0, MThreadTileSize, 1>{}([&](auto m) {
+            static_for<0, NThreadTileSize, 1>{}([&](auto n) {
+                constexpr auto offset = thread_buf_desc_M_N.CalculateOffset(make_tuple(m, n));
+
+                functor(c_thread_buf(Number<offset>{}),
+                        a_thread_buf(Number<offset>{}),
+                        b_thread_buf(Number<offset>{}));
+            });
+        });
+
+        // TODO - global write
         (void)c_global_buf;
-        (void)a_thread_buf;
-        (void)b_thread_buf;
-        (void)c_thread_buf;
-        (void)functor;
+        // c_global_write.Run(
+        //     thread_buf_desc_M_N, c_thread_buf, c_grid_desc_m_n, make_tuple(I0, I0),
+        //     c_global_buf);
     }
 };
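The per-thread start coordinate comes from CalculateElementwiseIndex: the 1D block id is decomposed column-major along M (block_id % gridSize_m), the 1D thread id is decomposed by the cluster descriptor, and each thread then owns one MThreadTileSize x NThreadTileSize tile. A minimal host-side sketch of that arithmetic with plain integers follows; the 2D thread decomposition used here is an assumption, since the real code delegates it to make_cluster_descriptor(..., Sequence<1, 0>{}).

    // Host-side sketch of the index math, not code from the repository.
    #include <cstdio>
    #include <utility>

    constexpr int MThreadPerBlock = 16, NThreadPerBlock = 16; // values from the example
    constexpr int MThreadTileSize = 8, NThreadTileSize = 8;
    constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
    constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;

    std::pair<int, int> calculate_elementwise_index(int block_id, int thread_id, int M)
    {
        const int gridSize_m     = M / M_BlockTileSize;
        const int block_2d_idx_m = block_id % gridSize_m; // column-major over blocks, as in the diff
        const int block_2d_idx_n = block_id / gridSize_m;

        // Assumed thread arrangement: N is the fastest-varying dimension inside a block.
        const int thread_2d_idx_m = thread_id / NThreadPerBlock;
        const int thread_2d_idx_n = thread_id % NThreadPerBlock;

        return {block_2d_idx_m * M_BlockTileSize + thread_2d_idx_m * MThreadTileSize,
                block_2d_idx_n * N_BlockTileSize + thread_2d_idx_n * NThreadTileSize};
    }

    int main()
    {
        const auto [m0, n0] = calculate_elementwise_index(/*block_id=*/3, /*thread_id=*/17, /*M=*/1024);
        std::printf("this thread's 8x8 tile starts at (m=%d, n=%d)\n", m0, n0);
        return 0;
    }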