Commit f919809d authored by rocking's avatar rocking
Browse files

Move threadPerBlock to argument

parent a41f5481
...@@ -179,7 +179,6 @@ using DeviceElementwiseSubExpInstance = ...@@ -179,7 +179,6 @@ using DeviceElementwiseSubExpInstance =
CDataType, CDataType,
EltwiseComputeDataType, EltwiseComputeDataType,
SubExp, SubExp,
256,
8>; 8>;
using DeviceElementwiseDivInstance = using DeviceElementwiseDivInstance =
...@@ -188,7 +187,6 @@ using DeviceElementwiseDivInstance = ...@@ -188,7 +187,6 @@ using DeviceElementwiseDivInstance =
CDataType, CDataType,
EltwiseComputeDataType, EltwiseComputeDataType,
Div, Div,
256,
8>; 8>;
using HostGemmInstance = ck::tensor_operation::host:: using HostGemmInstance = ck::tensor_operation::host::
...@@ -416,7 +414,8 @@ int main(int argc, char* argv[]) ...@@ -416,7 +414,8 @@ int main(int argc, char* argv[])
{StrideC, 1}, {StrideC, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideC, 1},
SubExp{}); SubExp{},
256);
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get())) if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{ {
...@@ -466,7 +465,8 @@ int main(int argc, char* argv[]) ...@@ -466,7 +465,8 @@ int main(int argc, char* argv[])
{StrideC, 1}, {StrideC, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideC, 1},
Div{}); Div{},
256);
if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get())) if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
{ {
......
...@@ -19,7 +19,8 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -19,7 +19,8 @@ struct DeviceBinaryElementwise : public BaseOperator
const std::vector<int>& stride_a, const std::vector<int>& stride_a,
const std::vector<int>& shape_b, const std::vector<int>& shape_b,
const std::vector<int>& stride_b, const std::vector<int>& stride_b,
ElementwiseFunctor functor) = 0; ElementwiseFunctor functor,
index_t threadPerBlock) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -15,7 +15,6 @@ template <typename ADataType, ...@@ -15,7 +15,6 @@ template <typename ADataType,
typename CDataType, typename CDataType,
typename ComputeDataType, typename ComputeDataType,
typename ElementwiseFunctor, typename ElementwiseFunctor,
index_t ThreadPerBlock,
index_t ScalarPerVector> index_t ScalarPerVector>
struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor> struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
{ {
...@@ -23,7 +22,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -23,7 +22,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
static auto MakeDescriptor_M0(const std::vector<int>& shape, static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride, const std::vector<int>& stride,
index_t gridSize) index_t gridSize,
index_t threadPerBlock)
{ {
const int m = shape[0]; const int m = shape[0];
const int n = shape[1]; const int n = shape[1];
...@@ -41,7 +41,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -41,7 +41,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// pad // pad
const auto m0 = desc_m0.GetLength(I0); const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * ThreadPerBlock * ScalarPerVector; const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0; const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad = const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0, transform_tensor_descriptor(desc_m0,
...@@ -51,7 +51,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -51,7 +51,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
return desc_m0_pad; return desc_m0_pad;
} }
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1)); using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType, using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType, BDataType,
CDataType, CDataType,
...@@ -69,16 +69,18 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -69,16 +69,18 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const std::vector<int>& stride_a, const std::vector<int>& stride_a,
const std::vector<int>& stride_b, const std::vector<int>& stride_b,
const std::vector<int>& stride_c, const std::vector<int>& stride_c,
ElementwiseFunctor functor) ElementwiseFunctor functor,
index_t threadPerBlock)
: p_a_(p_a), : p_a_(p_a),
p_b_(p_b), p_b_(p_b),
p_c_(p_c), p_c_(p_c),
functor_(functor), functor_(functor),
threadPerBlock_(threadPerBlock),
gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
{ {
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_); a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_); b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_); c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
} }
const ADataType* p_a_; const ADataType* p_a_;
...@@ -88,6 +90,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -88,6 +90,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
GridDesc_M0 b_grid_desc_m0_; GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_; GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_; ElementwiseFunctor functor_;
index_t threadPerBlock_;
index_t gridSize_; index_t gridSize_;
}; };
...@@ -102,12 +105,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -102,12 +105,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
CDataType, CDataType,
GridDesc_M0, GridDesc_M0,
ElementwiseFunctor>; ElementwiseFunctor>;
float avgTime = 0; float avgTime = 0;
if(nrepeat == 0) if(nrepeat == 0)
{ {
launch_kernel(kernel, launch_kernel(kernel,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(ThreadPerBlock), dim3(arg.threadPerBlock_),
0, 0,
arg.p_a_, arg.p_a_,
arg.p_b_, arg.p_b_,
...@@ -122,7 +125,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -122,7 +125,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
avgTime = launch_and_time_kernel(kernel, avgTime = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(ThreadPerBlock), dim3(arg.threadPerBlock_),
0, 0,
arg.p_a_, arg.p_a_,
arg.p_b_, arg.p_b_,
...@@ -164,7 +167,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -164,7 +167,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const std::vector<int>& stride_a, const std::vector<int>& stride_a,
const std::vector<int>& stride_b, const std::vector<int>& stride_b,
const std::vector<int>& stride_c, const std::vector<int>& stride_c,
ElementwiseFunctor functor) override ElementwiseFunctor functor,
index_t threadPerBlock) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -173,7 +177,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -173,7 +177,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
stride_a, stride_a,
stride_b, stride_b,
stride_c, stride_c,
functor); functor,
threadPerBlock);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...@@ -188,7 +193,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu ...@@ -188,7 +193,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// clang-format off // clang-format off
str << "DeviceBinaryElementwise_2D" str << "DeviceBinaryElementwise_2D"
<< "<" << "<"
<< "ThreadPerBlock = " << ThreadPerBlock
<< "ScalarPerVector = " << ScalarPerVector << "ScalarPerVector = " << ScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment