Commit 0e6bf342 authored by rocking's avatar rocking
Browse files

Rename elementwise p[ to binary elementwise

parent 5fa209af
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "device_reduce_blockwise.hpp" #include "device_reduce_blockwise.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp" #include "reduction_operator_mapping.hpp"
#include "device_elementwise_2d.hpp" #include "device_binary_elementwise_2d.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -170,7 +170,7 @@ struct Div ...@@ -170,7 +170,7 @@ struct Div
}; };
using DeviceElementwiseSubExpInstance = using DeviceElementwiseSubExpInstance =
ck::tensor_operation::device::DeviceElementwise_2D<CDataType, ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
CDataType, CDataType,
CDataType, CDataType,
EltwiseComputeDataType, EltwiseComputeDataType,
...@@ -180,7 +180,7 @@ using DeviceElementwiseSubExpInstance = ...@@ -180,7 +180,7 @@ using DeviceElementwiseSubExpInstance =
8>; 8>;
using DeviceElementwiseDivInstance = ck::tensor_operation::device:: using DeviceElementwiseDivInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>; DeviceBinaryElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
using HostGemmInstance = ck::tensor_operation::host:: using HostGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>; ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
...@@ -412,7 +412,7 @@ int main(int argc, char* argv[]) ...@@ -412,7 +412,7 @@ int main(int argc, char* argv[])
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get())) if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!"); "DeviceBinaryElementwise_2D instance, exiting!");
}; };
auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer(); auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
...@@ -462,7 +462,7 @@ int main(int argc, char* argv[]) ...@@ -462,7 +462,7 @@ int main(int argc, char* argv[])
if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get())) if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!"); "DeviceBinaryElementwise_2D instance, exiting!");
}; };
auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer(); auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer();
......
...@@ -9,7 +9,7 @@ namespace tensor_operation { ...@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename ElementwiseFunctor> template <typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator struct DeviceBinaryElementwise : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include <vector> #include <vector>
#include "device.hpp" #include "device.hpp"
#include "device_elementwise.hpp" #include "device_binary_elementwise.hpp"
#include "gridwise_elementwise_1d.hpp" #include "gridwise_binary_elementwise_1d.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,7 +18,7 @@ template <typename ADataType, ...@@ -18,7 +18,7 @@ template <typename ADataType,
index_t ThreadPerBlock, index_t ThreadPerBlock,
index_t ThreadTileSize, index_t ThreadTileSize,
index_t ScalarPerVector> index_t ScalarPerVector>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor> struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
{ {
static_assert(ThreadTileSize % ScalarPerVector == 0); static_assert(ThreadTileSize % ScalarPerVector == 0);
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize; static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
...@@ -52,7 +52,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor> ...@@ -52,7 +52,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
} }
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1})); using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
using GridwiseEltwise = GridwiseElementwise_1D<ADataType, using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType, BDataType,
CDataType, CDataType,
ComputeDataType, ComputeDataType,
...@@ -101,7 +101,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor> ...@@ -101,7 +101,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
const auto kernel = kernel_elementwise_1d<GridwiseEltwise, const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
...@@ -192,8 +192,11 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor> ...@@ -192,8 +192,11 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceElementwise_2D" str << "DeviceBinaryElementwise_2D"
<< "<" << "<"
<< "ThreadPerBlock = " << ThreadPerBlock
<< "ThreadTileSize = " << ThreadTileSize
<< "ScalarPerVector = " << ScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace ck { namespace ck {
template <typename GridwiseEltwise, template <typename GridwiseBinEltwise,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
...@@ -21,7 +21,7 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global, ...@@ -21,7 +21,7 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
const GridDesc_M0 c_grid_desc_m0, const GridDesc_M0 c_grid_desc_m0,
const ElementwiseFunctor functor) const ElementwiseFunctor functor)
{ {
GridwiseEltwise::Run(p_a_global, GridwiseBinEltwise::Run(p_a_global,
p_b_global, p_b_global,
p_c_global, p_c_global,
a_grid_desc_m0, a_grid_desc_m0,
...@@ -39,7 +39,7 @@ template <typename ADataType, ...@@ -39,7 +39,7 @@ template <typename ADataType,
index_t ThreadPerBlock, index_t ThreadPerBlock,
index_t ThreadTileSize, index_t ThreadTileSize,
index_t ScalarPerVector> index_t ScalarPerVector>
struct GridwiseElementwise_1D struct GridwiseBinaryElementwise_1D
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize; static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment