Commit 5d36f7a2 authored by rocking
parent 88d621ac

Rewrite the elementwise operation.

Let memory accesses coalesce between blocks.
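The change replaces the per-block tiling scheme with a grid-stride loop. Stripped of CK's descriptor and buffer machinery, the before/after access pattern looks roughly like the sketch below (a minimal illustration, not the actual kernel: the addition stands in for the Sub_Exp/Div functors, and m0 is assumed to be a multiple of ScalarPerVector, which the new m0 % ScalarPerVector support check enforces):

    #include <hip/hip_runtime.h> // HIP; not needed when compiling as CUDA

    // Before: block b looped over its own contiguous BlockTileSize window
    // starting at b * BlockTileSize, so concurrently running blocks touched
    // regions of the buffer that were BlockTileSize elements apart.
    //
    // After (this commit): a grid-stride loop. Per iteration the whole grid
    // reads one contiguous loop_step-sized window, so consecutive blocks sit
    // next to each other in memory - the "coalesce between blocks" of the
    // commit title.
    template <int ScalarPerVector>
    __global__ void binary_elementwise_1d(const float* a, const float* b, float* c, int m0)
    {
        const int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
        const int loop_step        = gridDim.x * blockDim.x * ScalarPerVector;

        for(int i = global_thread_id * ScalarPerVector; i < m0; i += loop_step)
        {
            for(int v = 0; v < ScalarPerVector; ++v)
                c[i + v] = a[i + v] + b[i + v]; // stand-in for functor(a, b)
        }
    }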
@@ -171,16 +171,21 @@ struct Div
 using DeviceElementwiseSubExpInstance =
     ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
                                                              CDataType,
                                                              CDataType,
                                                              EltwiseComputeDataType,
                                                              Sub_Exp,
                                                              256,
-                                                             32,
                                                              8>;

-using DeviceElementwiseDivInstance = ck::tensor_operation::device::
-    DeviceBinaryElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
+using DeviceElementwiseDivInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
+                                                             CDataType,
+                                                             CDataType,
+                                                             EltwiseComputeDataType,
+                                                             Div,
+                                                             256,
+                                                             8>;

 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
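Both instances lose the second-to-last `32` (the old ThreadTileSize) from their template argument lists; after this commit DeviceBinaryElementwise_2D is parameterized only by ThreadPerBlock = 256 and ScalarPerVector = 8, and the per-thread work count is derived at run time from the grid size instead.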
@@ -12,15 +12,14 @@ template <typename ElementwiseFunctor>
 struct DeviceBinaryElementwise : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        const std::vector<int>& shape_a,
-                        const std::vector<int>& stride_a,
-                        const std::vector<int>& shape_b,
-                        const std::vector<int>& stride_b,
-                        ElementwiseFunctor functor) = 0;
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              const std::vector<int>& shape_a,
+                                                              const std::vector<int>& stride_a,
+                                                              const std::vector<int>& shape_b,
+                                                              const std::vector<int>& stride_b,
+                                                              ElementwiseFunctor functor) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
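For orientation, a hypothetical call site against this type-erased interface might look as follows (pointer names, shapes, and strides are made up, and BaseInvoker is assumed to expose CK's usual polymorphic Run over a BaseArgument pointer):

    // Hypothetical sketch: compute C = A / B elementwise on an M x N
    // row-major tensor through the base-class interface above.
    std::vector<int> shape  = {M, N};
    std::vector<int> stride = {N, 1};

    DeviceElementwiseDivInstance div_op;
    auto arg_ptr     = div_op.MakeArgumentPointer(p_a, p_b, p_c,
                                                  shape, stride,  // A
                                                  shape, stride,  // B
                                                  Div{});
    auto invoker_ptr = div_op.MakeInvokerPointer();
    invoker_ptr->Run(arg_ptr.get(), /*nrepeat=*/1); // assumed Run signature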
@@ -16,15 +16,14 @@ template <typename ADataType,
           typename ComputeDataType,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
-          index_t ThreadTileSize,
           index_t ScalarPerVector>
 struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
 {
-    static_assert(ThreadTileSize % ScalarPerVector == 0);
-    static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
     static constexpr auto I0 = Number<0>{};

-    static auto MakeDescriptor_M0(const std::vector<int>& shape, const std::vector<int>& stride)
+    static auto MakeDescriptor_M0(const std::vector<int>& shape,
+                                  const std::vector<int>& stride,
+                                  index_t gridSize)
     {
         const int m = shape[0];
         const int n = shape[1];
@@ -41,8 +40,9 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
                                     make_tuple(Sequence<0>{}));

         // pad
         const auto m0  = desc_m0.GetLength(I0);
-        const auto pad = math::integer_least_multiple(m0, BlockTileSize) - m0;
+        const index_t loop_step = gridSize * ThreadPerBlock * ScalarPerVector;
+        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
         const auto desc_m0_pad =
             transform_tensor_descriptor(desc_m0,
                                         make_tuple(make_right_pad_transform(m0, pad)),
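Concretely, with the values this commit uses (gridSize = 128, ThreadPerBlock = 256, ScalarPerVector = 8): loop_step = 128 * 256 * 8 = 262,144. A flattened tensor of, say, m0 = 1,000,000 elements is padded by 48,576 up to 1,048,576 = 4 * 262,144, so every thread later runs exactly num_iter = 4 loop iterations.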
@@ -51,15 +51,13 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         return desc_m0_pad;
     }

-    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
+    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1));
     using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             ComputeDataType,
                                                             GridDesc_M0,
                                                             ElementwiseFunctor,
-                                                            ThreadPerBlock,
-                                                            ThreadTileSize,
                                                             ScalarPerVector>;

     struct Argument : public BaseArgument
@@ -75,11 +73,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
             : p_a_(p_a),
               p_b_(p_b),
               p_c_(p_c),
-              a_grid_desc_m0_(MakeDescriptor_M0(shape, stride_a)),
-              b_grid_desc_m0_(MakeDescriptor_M0(shape, stride_b)),
-              c_grid_desc_m0_(MakeDescriptor_M0(shape, stride_c)),
-              functor_(functor)
+              functor_(functor),
+              gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
         {
+            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_);
+            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_);
+            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_);
         }

         const ADataType* p_a_;
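The grid size thereby becomes a property of the Argument instead of something the Invoker derives from the tensor: previously the launch used m0 / BlockTileSize blocks, so the grid grew with the problem, while now a fixed 128 blocks are launched and each thread loops over the remainder. The FIXME marks that 128 should eventually be chosen from the device's CU count.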
@@ -89,30 +88,25 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         GridDesc_M0 b_grid_desc_m0_;
         GridDesc_M0 c_grid_desc_m0_;
         ElementwiseFunctor functor_;
+        index_t gridSize_;
     };

     struct Invoker : public BaseInvoker
     {
-        index_t CalculateGridSize(const GridDesc_M0& grid_desc_m0)
-        {
-            const auto gridTileSize = grid_desc_m0.GetLength(I0);
-            return gridTileSize / BlockTileSize;
-        }
-
         float Run(const Argument& arg, int nrepeat = 1)
         {
+            (void)arg;
             const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       GridDesc_M0,
                                                       ElementwiseFunctor>;

             float avgTime = 0;
-            const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m0_);

             if(nrepeat == 0)
             {
                 launch_kernel(kernel,
-                              dim3(gridSize),
+                              dim3(arg.gridSize_),
                               dim3(ThreadPerBlock),
                               0,
                               arg.p_a_,
@@ -127,7 +121,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
             {
                 avgTime = launch_and_time_kernel(kernel,
                                                  nrepeat,
-                                                 dim3(gridSize),
+                                                 dim3(arg.gridSize_),
                                                  dim3(ThreadPerBlock),
                                                  0,
                                                  arg.p_a_,
@@ -157,7 +151,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         // m * n
         const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);

-        if(m0 % BlockTileSize != 0)
+        if(m0 % ScalarPerVector != 0)
             return false;

         return true;
@@ -195,7 +189,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         str << "DeviceBinaryElementwise_2D"
             << "<"
             << "ThreadPerBlock = " << ThreadPerBlock
-            << "ThreadTileSize = " << ThreadTileSize
             << "ScalarPerVector = " << ScalarPerVector
             << ">";
         // clang-format on
@@ -36,13 +36,10 @@ template <typename ADataType,
           typename ComputeDataType,
           typename GridDesc_M0,
           typename ElementwiseFunctor,
-          index_t ThreadPerBlock,
-          index_t ThreadTileSize,
           index_t ScalarPerVector>
 struct GridwiseBinaryElementwise_1D
 {
     static constexpr auto I0 = Number<0>{};
-    static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;

     static constexpr auto thread_desc_M0 =
         make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
@@ -50,10 +47,8 @@ struct GridwiseBinaryElementwise_1D
     static __device__ __host__ auto CalculateElementwiseIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t block_id  = get_block_1d_id();
-
-        return make_multi_index(block_id * BlockTileSize + thread_id * ScalarPerVector);
+        const index_t global_thread_id = get_thread_global_1d_id();
+        return make_multi_index(global_thread_id * ScalarPerVector);
     }

     __device__ static void Run(const ADataType* __restrict__ p_a_global,
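Within one block both formulas are equivalent (lane t starts ScalarPerVector * t past the block's base, so intra-block loads were already coalesced); what changes is the base. With the defaults above, block b previously started at b * BlockTileSize = b * 8192 and stayed inside its private tile, leaving the concurrently active blocks scattered across the buffer, whereas now block b starts at b * 256 * 8 = b * 2048 and the whole grid covers one contiguous 262,144-element span per iteration.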
@@ -116,8 +111,13 @@ struct GridwiseBinaryElementwise_1D
                                                false>{
                 c_grid_desc_m0, thread_to_global_offset, PassThrough{}};

-        int num_iter                         = ThreadTileSize / ScalarPerVector;
-        constexpr auto thread_to_global_step = make_multi_index(ThreadPerBlock * ScalarPerVector);
+        const index_t threadPerBlock = get_block_size();
+        const index_t blockPerGrid   = get_grid_size();
+        const auto m0                = c_grid_desc_m0.GetLength(I0);
+        const index_t loop_step      = blockPerGrid * threadPerBlock * ScalarPerVector;
+        const auto loop_step_index   = make_multi_index(loop_step);
+
+        index_t num_iter = m0 / (loop_step);
         do
         {
             // read and process ScalarPerVector elements
@@ -140,9 +140,9 @@ struct GridwiseBinaryElementwise_1D
                                   c_grid_desc_m0,
                                   c_global_buf);

-            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, thread_to_global_step);
-            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, thread_to_global_step);
-            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, thread_to_global_step);
+            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
+            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
+            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);

         } while(--num_iter);
     }
 };
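The bare do/while is only safe because MakeDescriptor_M0 pads m0 up to a multiple of loop_step: num_iter = m0 / loop_step is then exact and at least 1, so no tail handling or zero-iteration guard is needed.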
@@ -7,10 +7,14 @@ __device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }

 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

+__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
+
 __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

 __device__ index_t get_grid_size() { return gridDim.x; }

+__device__ index_t get_block_size() { return blockDim.x; }
+
 } // namespace ck
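The two new wrappers are thin aliases for the HIP/CUDA builtins (blockIdx.x * blockDim.x + threadIdx.x and blockDim.x). Reading blockDim.x and gridDim.x at run time via get_block_size() and get_grid_size() is what lets GridwiseBinaryElementwise_1D drop ThreadPerBlock and ThreadTileSize from its template parameter list and compute loop_step inside the kernel.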