Commit c4d610be authored by rocking's avatar rocking
Browse files

Move thread per block to the parameter of constructor

parent 83f75313
...@@ -101,8 +101,7 @@ int main() ...@@ -101,8 +101,7 @@ int main()
{Stride, 1}, {Stride, 1},
{0, 1}, // broadcast in first dimension {0, 1}, // broadcast in first dimension
{Stride, 1}, {Stride, 1},
Add{}, Add{});
256);
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
......
...@@ -80,8 +80,7 @@ int main() ...@@ -80,8 +80,7 @@ int main()
{1}, {1},
{1}, {1},
{1}, {1},
Add{}, Add{});
256);
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
......
...@@ -82,8 +82,7 @@ int main() ...@@ -82,8 +82,7 @@ int main()
ck::to_int_vector(a_m.mDesc.GetStrides()), ck::to_int_vector(a_m.mDesc.GetStrides()),
ck::to_int_vector(b_m.mDesc.GetStrides()), ck::to_int_vector(b_m.mDesc.GetStrides()),
ck::to_int_vector(c_m.mDesc.GetStrides()), ck::to_int_vector(c_m.mDesc.GetStrides()),
Add{}, Add{});
256);
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
......
...@@ -19,6 +19,11 @@ template <typename ADataType, ...@@ -19,6 +19,11 @@ template <typename ADataType,
index_t ScalarPerVector> index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public BaseOperator
{ {
DeviceBinaryElementwise(index_t threadPerBlock = 256)
: BaseOperator(), threadPerBlock_(threadPerBlock)
{
}
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
template <typename Desc_M0> template <typename Desc_M0>
...@@ -85,12 +90,11 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -85,12 +90,11 @@ struct DeviceBinaryElementwise : public BaseOperator
p_b_(p_b), p_b_(p_b),
p_c_(p_c), p_c_(p_c),
functor_(functor), functor_(functor),
threadPerBlock_(threadPerBlock),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{ {
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_); a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_); b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_); c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock);
} }
const ADataType* p_a_; const ADataType* p_a_;
...@@ -100,12 +104,13 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -100,12 +104,13 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0 b_grid_desc_m0_; GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_; GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_; ElementwiseFunctor functor_;
index_t threadPerBlock_;
index_t gridSize_; index_t gridSize_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
Invoker(index_t threadPerBlock) : BaseInvoker(), threadPerBlock_(threadPerBlock) {}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise, const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
...@@ -118,7 +123,7 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -118,7 +123,7 @@ struct DeviceBinaryElementwise : public BaseOperator
float elapsed_time = launch_and_time_kernel(stream_config, float elapsed_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(arg.threadPerBlock_), dim3(threadPerBlock_),
0, 0,
arg.p_a_, arg.p_a_,
arg.p_b_, arg.p_b_,
...@@ -136,6 +141,8 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -136,6 +141,8 @@ struct DeviceBinaryElementwise : public BaseOperator
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
index_t threadPerBlock_;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -161,8 +168,7 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -161,8 +168,7 @@ struct DeviceBinaryElementwise : public BaseOperator
std::vector<int> stride_a, std::vector<int> stride_a,
std::vector<int> stride_b, std::vector<int> stride_b,
std::vector<int> stride_c, std::vector<int> stride_c,
ElementwiseFunctor functor, ElementwiseFunctor functor)
index_t threadPerBlock)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -172,12 +178,12 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -172,12 +178,12 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_b, stride_b,
stride_c, stride_c,
functor, functor,
threadPerBlock); threadPerBlock_);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{threadPerBlock_});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -193,6 +199,8 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -193,6 +199,8 @@ struct DeviceBinaryElementwise : public BaseOperator
return str.str(); return str.str();
} }
index_t threadPerBlock_;
}; };
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment