Commit 0e6bf342 authored by rocking's avatar rocking
Browse files

Rename elementwise p[ to binary elementwise

parent 5fa209af
......@@ -23,7 +23,7 @@
#include "device_reduce_blockwise.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_elementwise_2d.hpp"
#include "device_binary_elementwise_2d.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -170,7 +170,7 @@ struct Div
};
using DeviceElementwiseSubExpInstance =
ck::tensor_operation::device::DeviceElementwise_2D<CDataType,
ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
CDataType,
CDataType,
EltwiseComputeDataType,
......@@ -180,7 +180,7 @@ using DeviceElementwiseSubExpInstance =
8>;
using DeviceElementwiseDivInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
DeviceBinaryElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
using HostGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......@@ -412,7 +412,7 @@ int main(int argc, char* argv[])
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!");
"DeviceBinaryElementwise_2D instance, exiting!");
};
auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
......@@ -462,7 +462,7 @@ int main(int argc, char* argv[])
if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!");
"DeviceBinaryElementwise_2D instance, exiting!");
};
auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer();
......
......@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
template <typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
struct DeviceBinaryElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......
......@@ -3,8 +3,8 @@
#include <vector>
#include "device.hpp"
#include "device_elementwise.hpp"
#include "gridwise_elementwise_1d.hpp"
#include "device_binary_elementwise.hpp"
#include "gridwise_binary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
......@@ -18,7 +18,7 @@ template <typename ADataType,
index_t ThreadPerBlock,
index_t ThreadTileSize,
index_t ScalarPerVector>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
{
static_assert(ThreadTileSize % ScalarPerVector == 0);
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
......@@ -51,16 +51,16 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
return desc_m0_pad;
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
ElementwiseFunctor,
ThreadPerBlock,
ThreadTileSize,
ScalarPerVector>;
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
ElementwiseFunctor,
ThreadPerBlock,
ThreadTileSize,
ScalarPerVector>;
struct Argument : public BaseArgument
{
......@@ -101,7 +101,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
float Run(const Argument& arg, int nrepeat = 1)
{
const auto kernel = kernel_elementwise_1d<GridwiseEltwise,
const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
ADataType,
BDataType,
CDataType,
......@@ -192,8 +192,11 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwise_2D"
str << "DeviceBinaryElementwise_2D"
<< "<"
<< "ThreadPerBlock = " << ThreadPerBlock
<< "ThreadTileSize = " << ThreadTileSize
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on
......
......@@ -7,7 +7,7 @@
namespace ck {
template <typename GridwiseEltwise,
template <typename GridwiseBinEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
......@@ -21,13 +21,13 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
const GridDesc_M0 c_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m0,
b_grid_desc_m0,
c_grid_desc_m0,
functor);
GridwiseBinEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m0,
b_grid_desc_m0,
c_grid_desc_m0,
functor);
}
template <typename ADataType,
......@@ -39,7 +39,7 @@ template <typename ADataType,
index_t ThreadPerBlock,
index_t ThreadTileSize,
index_t ScalarPerVector>
struct GridwiseElementwise_1D
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment