Commit b456d5e5 authored by rocking
Browse files

Add template argument of dim. Prepare to support multiple dimensions

parent c2626122
...@@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32; ...@@ -26,7 +26,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device:: using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 8>; DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
template <typename HostTensorA, template <typename HostTensorA,
typename HostTensorB, typename HostTensorB,
......
...@@ -15,15 +15,16 @@ template <typename ADataType, ...@@ -15,15 +15,16 @@ template <typename ADataType,
typename CDataType, typename CDataType,
typename ComputeDataType, typename ComputeDataType,
typename ElementwiseFunctor, typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector> index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public BaseOperator
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static auto MakeDescriptor_M0(const std::vector<int>& shape, static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
const std::vector<int>& stride, const std::vector<int>& stride,
index_t gridSize, index_t gridSize,
index_t threadPerBlock) index_t threadPerBlock)
{ {
const int m = shape[0]; const int m = shape[0];
const int n = shape[1]; const int n = shape[1];
...@@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -51,6 +52,17 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad; return desc_m0_pad;
} }
// Selects the descriptor builder according to the compile-time template argument `Dim`.
// Dim == 2 forwards all arguments to MakeDescriptor_M0_2d. Any other value returns the
// descriptor produced by make_naive_tensor_descriptor(make_tuple(0), make_tuple(0)) and
// ignores the arguments.
// NOTE(review): the non-2D result looks like a placeholder until more ranks are
// supported — confirm before relying on it.
static auto MakeDescriptor_M0(const std::vector<int>& shape,
                              const std::vector<int>& stride,
                              index_t gridSize,
                              index_t threadPerBlock)
{
    if constexpr(Dim != 2)
    {
        // Rank without a dedicated builder: same fallback descriptor as before.
        return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
    }
    else
    {
        return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
    }
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType, using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType, BDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment