Commit c4d610be authored by rocking's avatar rocking
Browse files

Move thread per block to the parameter of constructor

parent 83f75313
......@@ -101,8 +101,7 @@ int main()
{Stride, 1},
{0, 1}, // broadcast in first dimension
{Stride, 1},
Add{},
256);
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
......
......@@ -80,8 +80,7 @@ int main()
{1},
{1},
{1},
Add{},
256);
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
......
......@@ -82,8 +82,7 @@ int main()
ck::to_int_vector(a_m.mDesc.GetStrides()),
ck::to_int_vector(b_m.mDesc.GetStrides()),
ck::to_int_vector(c_m.mDesc.GetStrides()),
Add{},
256);
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
......
......@@ -19,6 +19,11 @@ template <typename ADataType,
index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
DeviceBinaryElementwise(index_t threadPerBlock = 256)
: BaseOperator(), threadPerBlock_(threadPerBlock)
{
}
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
......@@ -85,12 +90,11 @@ struct DeviceBinaryElementwise : public BaseOperator
p_b_(p_b),
p_c_(p_c),
functor_(functor),
threadPerBlock_(threadPerBlock),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock);
}
const ADataType* p_a_;
......@@ -100,12 +104,13 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t threadPerBlock_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
Invoker(index_t threadPerBlock) : BaseInvoker(), threadPerBlock_(threadPerBlock) {}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
......@@ -118,7 +123,7 @@ struct DeviceBinaryElementwise : public BaseOperator
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.threadPerBlock_),
dim3(threadPerBlock_),
0,
arg.p_a_,
arg.p_b_,
......@@ -136,6 +141,8 @@ struct DeviceBinaryElementwise : public BaseOperator
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
index_t threadPerBlock_;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -161,8 +168,7 @@ struct DeviceBinaryElementwise : public BaseOperator
std::vector<int> stride_a,
std::vector<int> stride_b,
std::vector<int> stride_c,
ElementwiseFunctor functor,
index_t threadPerBlock)
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -172,12 +178,12 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_b,
stride_c,
functor,
threadPerBlock);
threadPerBlock_);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{threadPerBlock_});
}
std::string GetTypeString() const override
......@@ -193,6 +199,8 @@ struct DeviceBinaryElementwise : public BaseOperator
return str.str();
}
index_t threadPerBlock_;
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment