Commit f919809d authored by rocking
Browse files

Move threadPerBlock to argument

parent a41f5481
......@@ -179,7 +179,6 @@ using DeviceElementwiseSubExpInstance =
CDataType,
EltwiseComputeDataType,
SubExp,
256,
8>;
using DeviceElementwiseDivInstance =
......@@ -188,7 +187,6 @@ using DeviceElementwiseDivInstance =
CDataType,
EltwiseComputeDataType,
Div,
256,
8>;
using HostGemmInstance = ck::tensor_operation::host::
......@@ -416,7 +414,8 @@ int main(int argc, char* argv[])
{StrideC, 1},
{0, 1},
{StrideC, 1},
SubExp{});
SubExp{},
256);
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{
......@@ -466,7 +465,8 @@ int main(int argc, char* argv[])
{StrideC, 1},
{0, 1},
{StrideC, 1},
Div{});
Div{},
256);
if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
{
......
......@@ -19,7 +19,8 @@ struct DeviceBinaryElementwise : public BaseOperator
const std::vector<int>& stride_a,
const std::vector<int>& shape_b,
const std::vector<int>& stride_b,
ElementwiseFunctor functor) = 0;
ElementwiseFunctor functor,
index_t threadPerBlock) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -15,7 +15,6 @@ template <typename ADataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t ThreadPerBlock,
index_t ScalarPerVector>
struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
{
......@@ -23,7 +22,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize)
index_t gridSize,
index_t threadPerBlock)
{
const int m = shape[0];
const int n = shape[1];
......@@ -41,7 +41,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// pad
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * ThreadPerBlock * ScalarPerVector;
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
......@@ -51,7 +51,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
return desc_m0_pad;
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1));
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
......@@ -69,16 +69,18 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const std::vector<int>& stride_a,
const std::vector<int>& stride_b,
const std::vector<int>& stride_c,
ElementwiseFunctor functor)
ElementwiseFunctor functor,
index_t threadPerBlock)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
functor_(functor),
threadPerBlock_(threadPerBlock),
gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
}
const ADataType* p_a_;
......@@ -88,6 +90,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t threadPerBlock_;
index_t gridSize_;
};
......@@ -102,12 +105,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
CDataType,
GridDesc_M0,
ElementwiseFunctor>;
float avgTime = 0;
float avgTime = 0;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(arg.gridSize_),
dim3(ThreadPerBlock),
dim3(arg.threadPerBlock_),
0,
arg.p_a_,
arg.p_b_,
......@@ -122,7 +125,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
avgTime = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.gridSize_),
dim3(ThreadPerBlock),
dim3(arg.threadPerBlock_),
0,
arg.p_a_,
arg.p_b_,
......@@ -164,7 +167,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const std::vector<int>& stride_a,
const std::vector<int>& stride_b,
const std::vector<int>& stride_c,
ElementwiseFunctor functor) override
ElementwiseFunctor functor,
index_t threadPerBlock) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -173,7 +177,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
stride_a,
stride_b,
stride_c,
functor);
functor,
threadPerBlock);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......@@ -188,7 +193,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// clang-format off
str << "DeviceBinaryElementwise_2D"
<< "<"
<< "ThreadPerBlock = " << ThreadPerBlock
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment