Commit b456d5e5 authored by rocking's avatar rocking
Browse files

Add a `Dim` template argument to prepare for supporting multiple dimensions.

parent c2626122
......@@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 8>;
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
template <typename HostTensorA,
typename HostTensorB,
......
......@@ -15,15 +15,16 @@ template <typename ADataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
{
const int m = shape[0];
const int n = shape[1];
......@@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad;
}
// Selects the descriptor builder at compile time based on the Dim template argument.
// Dim == 2 forwards all four arguments unchanged to MakeDescriptor_M0_2d. Any other Dim
// currently returns a stand-in descriptor created by passing make_tuple(0) for both
// arguments of make_naive_tensor_descriptor (presumably lengths and strides -- confirm).
static auto MakeDescriptor_M0(const std::vector<int>& shape,
                              const std::vector<int>& stride,
                              index_t gridSize,
                              index_t threadPerBlock)
{
    if constexpr(Dim != 2)
    {
        // Ranks other than 2 are not implemented here yet; the runtime arguments are unused.
        return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
    }
    else
    {
        return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
    }
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment