Commit 5d36f7a2 authored by rocking

Rewrite the elementwise operation.

Let memory accesses coalesce across blocks.
parent 88d621ac
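The gist of the rewrite: instead of assigning each thread a fixed ThreadTileSize tile inside a per-block window, each thread now starts at its global thread id and strides by the footprint of the entire grid, so on every loop iteration the grid as a whole reads one contiguous, coalesced span. A minimal plain-HIP sketch of the new access pattern (an illustrative kernel, not code from this commit; for brevity it assumes m0 is a multiple of the loop step, which the descriptor padding below guarantees):

#include <hip/hip_runtime.h>

__global__ void binary_op_grid_stride(const float* a, const float* b, float* c, int m0)
{
    constexpr int ScalarPerVector = 8; // matches the instances in this commit
    // Start at this thread's global offset ...
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * ScalarPerVector;
    // ... and advance by the whole grid's footprint, not one block's.
    const int loop_step = gridDim.x * blockDim.x * ScalarPerVector;
    for(; i < m0; i += loop_step)
        for(int s = 0; s < ScalarPerVector; ++s)
            c[i + s] = a[i + s] / b[i + s]; // e.g. the Div functor
}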
@@ -171,16 +171,21 @@ struct Div
 using DeviceElementwiseSubExpInstance =
     ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
                                                              CDataType,
-                                                             CDataType,
-                                                             EltwiseComputeDataType,
-                                                             Sub_Exp,
-                                                             256,
-                                                             32,
-                                                             8>;
-using DeviceElementwiseDivInstance = ck::tensor_operation::device::
-    DeviceBinaryElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
+                                                             CDataType,
+                                                             EltwiseComputeDataType,
+                                                             Sub_Exp,
+                                                             256,
+                                                             8>;
+using DeviceElementwiseDivInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
+                                                             CDataType,
+                                                             CDataType,
+                                                             EltwiseComputeDataType,
+                                                             Div,
+                                                             256,
+                                                             8>;
 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......
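For orientation, a rough sketch of how these instances are driven through the base interface shown next (the device pointers, M, N, and the invoker's Run signature are assumptions for illustration, not part of this commit):

auto div_op      = DeviceElementwiseDivInstance{};
auto arg_ptr     = div_op.MakeArgumentPointer(p_a, p_b, p_c,  // device buffers (assumed)
                                              {M, N}, {N, 1}, // shape_a, stride_a (row-major)
                                              {M, N}, {N, 1}, // shape_b, stride_b
                                              Div{});
auto invoker_ptr = div_op.MakeInvokerPointer();
invoker_ptr->Run(arg_ptr.get()); // assumes a Run(BaseArgument*) overload on the invoker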
@@ -12,15 +12,14 @@ template <typename ElementwiseFunctor>
 struct DeviceBinaryElementwise : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        const std::vector<int>& shape_a,
-                        const std::vector<int>& stride_a,
-                        const std::vector<int>& shape_b,
-                        const std::vector<int>& stride_b,
-                        ElementwiseFunctor functor) = 0;
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              const std::vector<int>& shape_a,
+                                                              const std::vector<int>& stride_a,
+                                                              const std::vector<int>& shape_b,
+                                                              const std::vector<int>& stride_b,
+                                                              ElementwiseFunctor functor) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
......
@@ -16,15 +16,14 @@ template <typename ADataType,
           typename ComputeDataType,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
-          index_t ThreadTileSize,
           index_t ScalarPerVector>
 struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
 {
-    static_assert(ThreadTileSize % ScalarPerVector == 0);
-    static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
-    static constexpr auto I0           = Number<0>{};
+    static constexpr auto I0 = Number<0>{};

-    static auto MakeDescriptor_M0(const std::vector<int>& shape, const std::vector<int>& stride)
+    static auto MakeDescriptor_M0(const std::vector<int>& shape,
+                                  const std::vector<int>& stride,
+                                  index_t gridSize)
     {
         const int m = shape[0];
         const int n = shape[1];
@@ -41,8 +40,9 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
                                                        make_tuple(Sequence<0>{}));

         // pad
-        const auto m0  = desc_m0.GetLength(I0);
-        const auto pad = math::integer_least_multiple(m0, BlockTileSize) - m0;
+        const auto m0           = desc_m0.GetLength(I0);
+        const index_t loop_step = gridSize * ThreadPerBlock * ScalarPerVector;
+        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
         const auto desc_m0_pad =
             transform_tensor_descriptor(desc_m0,
                                         make_tuple(make_right_pad_transform(m0, pad)),
@@ -51,15 +51,13 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         return desc_m0_pad;
     }
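To make the padding concrete, a worked example with assumed sizes (gridSize = 128 as hard-coded further down, ThreadPerBlock = 256 and ScalarPerVector = 8 as in the instances above):

// 1000 x 1000 tensor, flattened: m0 = 1'000'000
// loop_step = gridSize * ThreadPerBlock * ScalarPerVector = 128 * 256 * 8 = 262'144
// integer_least_multiple(1'000'000, 262'144) = 4 * 262'144 = 1'048'576
// pad = 1'048'576 - 1'000'000 = 48'576, so the grid-stride loop always runs a
// whole number of iterations (here 4), with the out-of-range tail handled by
// the right-pad transform.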
-    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
+    using GridDesc_M0        = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1));
     using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             ComputeDataType,
                                                             GridDesc_M0,
                                                             ElementwiseFunctor,
-                                                            ThreadPerBlock,
-                                                            ThreadTileSize,
                                                             ScalarPerVector>;

     struct Argument : public BaseArgument
@@ -75,11 +73,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
             : p_a_(p_a),
               p_b_(p_b),
               p_c_(p_c),
-              a_grid_desc_m0_(MakeDescriptor_M0(shape, stride_a)),
-              b_grid_desc_m0_(MakeDescriptor_M0(shape, stride_b)),
-              c_grid_desc_m0_(MakeDescriptor_M0(shape, stride_c)),
-              functor_(functor)
+              functor_(functor),
+              gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
         {
+            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_);
+            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_);
+            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_);
         }

         const ADataType* p_a_;
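On the FIXME above: one plausible way to size the grid from the device's CU count through the HIP runtime (a sketch of the suggested future direction, not code from this commit; the blocks-per-CU factor is an assumption):

#include <hip/hip_runtime.h>

index_t CalculateGridSizeFromCUs()
{
    int device = 0;
    hipGetDevice(&device);
    hipDeviceProp_t prop{};
    hipGetDeviceProperties(&prop, device);
    constexpr int blocks_per_cu = 2; // assumed occupancy factor
    return prop.multiProcessorCount * blocks_per_cu;
}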
@@ -89,30 +88,25 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         GridDesc_M0 b_grid_desc_m0_;
         GridDesc_M0 c_grid_desc_m0_;
         ElementwiseFunctor functor_;
+        index_t gridSize_;
     };

     struct Invoker : public BaseInvoker
     {
-        index_t CalculateGridSize(const GridDesc_M0& grid_desc_m0)
-        {
-            const auto gridTileSize = grid_desc_m0.GetLength(I0);
-            return gridTileSize / BlockTileSize;
-        }
-
         float Run(const Argument& arg, int nrepeat = 1)
         {
-            const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
+            (void)arg;
+            const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       GridDesc_M0,
                                                       ElementwiseFunctor>;

-            float avgTime          = 0;
-            const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m0_);
+            float avgTime = 0;

             if(nrepeat == 0)
             {
                 launch_kernel(kernel,
-                              dim3(gridSize),
+                              dim3(arg.gridSize_),
                               dim3(ThreadPerBlock),
                               0,
                               arg.p_a_,
@@ -127,7 +121,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
             {
                 avgTime = launch_and_time_kernel(kernel,
                                                  nrepeat,
-                                                 dim3(gridSize),
+                                                 dim3(arg.gridSize_),
                                                  dim3(ThreadPerBlock),
                                                  0,
                                                  arg.p_a_,
@@ -157,7 +151,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         // m * n
         const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);

-        if(m0 % BlockTileSize != 0)
+        if(m0 % ScalarPerVector != 0)
             return false;

         return true;
@@ -195,7 +189,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
         str << "DeviceBinaryElementwise_2D"
             << "<"
             << "ThreadPerBlock = " << ThreadPerBlock
-            << "ThreadTileSize = " << ThreadTileSize
             << "ScalarPerVector = " << ScalarPerVector
             << ">";
         // clang-format on
......
@@ -36,13 +36,10 @@ template <typename ADataType,
           typename ComputeDataType,
           typename GridDesc_M0,
           typename ElementwiseFunctor,
-          index_t ThreadPerBlock,
-          index_t ThreadTileSize,
           index_t ScalarPerVector>
 struct GridwiseBinaryElementwise_1D
 {
-    static constexpr auto I0           = Number<0>{};
-    static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
+    static constexpr auto I0 = Number<0>{};

     static constexpr auto thread_desc_M0 =
         make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
@@ -50,10 +47,8 @@ struct GridwiseBinaryElementwise_1D
     static __device__ __host__ auto CalculateElementwiseIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t block_id  = get_block_1d_id();
-        return make_multi_index(block_id * BlockTileSize + thread_id * ScalarPerVector);
+        const index_t global_thread_id = get_thread_global_1d_id();
+        return make_multi_index(global_thread_id * ScalarPerVector);
     }

     __device__ static void Run(const ADataType* __restrict__ p_a_global,
@@ -116,8 +111,13 @@ struct GridwiseBinaryElementwise_1D
                                                            false>{
             c_grid_desc_m0, thread_to_global_offset, PassThrough{}};

-        int num_iter                         = ThreadTileSize / ScalarPerVector;
-        constexpr auto thread_to_global_step = make_multi_index(ThreadPerBlock * ScalarPerVector);
+        const index_t threadPerBlock = get_block_size();
+        const index_t blockPerGrid   = get_grid_size();
+        const auto m0                = c_grid_desc_m0.GetLength(I0);
+        const index_t loop_step      = blockPerGrid * threadPerBlock * ScalarPerVector;
+        const auto loop_step_index   = make_multi_index(loop_step);
+
+        index_t num_iter = m0 / (loop_step);
         do
         {
             // read and process ScalarPerVector elements
@@ -140,9 +140,9 @@ struct GridwiseBinaryElementwise_1D
                                 c_grid_desc_m0,
                                 c_global_buf);

-            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, thread_to_global_step);
-            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, thread_to_global_step);
-            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, thread_to_global_step);
+            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
+            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
+            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
         } while(--num_iter);
     }
 };
......
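Continuing the worked example from MakeDescriptor_M0: with the padded length, the trip count of the do-while above comes out to a small integer:

// m0 (padded) = 1'048'576, loop_step = 128 * 256 * 8 = 262'144
// num_iter    = 1'048'576 / 262'144 = 4
// Each of the 32'768 threads processes 4 vectors of 8 elements; on every
// iteration the whole grid touches one contiguous 262'144-element span.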
@@ -7,10 +7,14 @@ __device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }

 __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

+__device__ index_t get_grid_size() { return gridDim.x; }
+
+__device__ index_t get_block_size() { return blockDim.x; }
+
 } // namespace ck