Commit 3e811ccf authored by rocking's avatar rocking
Browse files

Add device op for elementwise 2d

parent cbbc7e52
......@@ -19,6 +19,7 @@
#include "device_reduce_blockwise.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_elementwise_2d.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -115,6 +116,17 @@ using DeviceReduceInstance =
1,
1>;
struct Sub
{
__host__ __device__ constexpr void operator()(F16& dst, const F16& src1, const F16& src2) const
{
dst = src1 - src2;
}
};
using DeviceElementwiseInstance =
ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, 256, Sub>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......@@ -184,6 +196,7 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<int> c_m_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1}));
Tensor<CDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
const auto i_inLengths = ck::to_int_vector(c_m_n.mDesc.GetLengths());
const auto i_inStrides = ck::to_int_vector(c_m_n.mDesc.GetStrides());
......@@ -196,6 +209,7 @@ int main(int argc, char* argv[])
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n.mDesc << std::endl;
std::cout << "c_m_n_max: " << c_m_n_max.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
switch(init_method)
{
......@@ -214,6 +228,7 @@ int main(int argc, char* argv[])
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem c_m_n_max_device_buf(sizeof(CDataType) * c_m_n_max.mDesc.GetElementSpace());
DeviceMem c_m_n_max_indices_dev(0);
DeviceMem d_m_n_device_buf(sizeof(CDataType) * d_m_n.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......@@ -245,34 +260,54 @@ int main(int argc, char* argv[])
gemm_invoker.Run(gemm_argument, nrepeat);
// do reduce max
auto reduce = DeviceReduceInstance{};
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
auto reduce_max = DeviceReduceInstance{};
auto wsSizeInBytes = reduce_max.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
DeviceMem ws_dev(wsSizeInBytes);
auto argument_ptr =
reduce.MakeArgumentPointer(i_inLengths,
i_inStrides,
i_outLengths,
i_outStrides,
reduceDims,
1,
0,
c_m_n_device_buf.GetDeviceBuffer(),
c_m_n_max_device_buf.GetDeviceBuffer(),
c_m_n_max_indices_dev.GetDeviceBuffer(),
ws_dev.GetDeviceBuffer(),
InElementwiseOperation{static_cast<int>(reduce_total_length)},
AccElementwiseOperation{static_cast<int>(reduce_total_length)});
if(!reduce.IsSupportedArgument(argument_ptr.get()))
auto reduce_max_argument_ptr = reduce_max.MakeArgumentPointer(
i_inLengths,
i_inStrides,
i_outLengths,
i_outStrides,
reduceDims,
1,
0,
c_m_n_device_buf.GetDeviceBuffer(),
c_m_n_max_device_buf.GetDeviceBuffer(),
c_m_n_max_indices_dev.GetDeviceBuffer(),
ws_dev.GetDeviceBuffer(),
InElementwiseOperation{static_cast<int>(reduce_total_length)},
AccElementwiseOperation{static_cast<int>(reduce_total_length)});
if(!reduce_max.IsSupportedArgument(reduce_max_argument_ptr.get()))
{
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
throw std::runtime_error(
"The runtime parameters seems not supported by the DeviceReduce instance, exiting!");
};
auto reduce_max_invoker_ptr = reduce_max.MakeInvokerPointer();
reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
// do broadcast sub
auto broadcastSub = DeviceElementwiseInstance{};
auto broadcastSub_argument_ptr =
broadcastSub.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
c_m_n_max_device_buf.GetDeviceBuffer(),
d_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{StrideC, 1},
{0, 1},
{StrideC, 1},
Sub{});
if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the DeviceElementwise_2D instance, exiting!");
};
auto invoker_ptr = reduce.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), nrepeat);
auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer();
broadcastSub_invoker_ptr->Run(broadcastSub_argument_ptr.get(), nrepeat);
// TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
// TODO = do_verification
......
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::vector<int>& shape_a,
const std::vector<int>& stride_a,
const std::vector<int>& shape_b,
const std::vector<int>& stride_b,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma once
#include <iostream>
#include <vector>
#include "device.hpp"
#include "device_elementwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
index_t BlockSize,
typename ElementwiseFunctor>
struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
{
static auto Make2dDescriptor_M_N(const std::vector<int>& shape, const std::vector<int>& stride)
{
return make_naive_tensor_descriptor(make_tuple(shape[0], shape[1]),
make_tuple(stride[0], stride[1]));
}
using GridDesc_M_N = decltype(Make2dDescriptor_M_N({1, 1}, {1, 1}));
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const std::vector<int>& shape,
const std::vector<int>& stride_a,
const std::vector<int>& stride_b,
const std::vector<int>& stride_c,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
a_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_a)),
b_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_b)),
c_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_c)),
functor_(functor)
{
}
const ADataType* p_a_;
const BDataType* p_b_;
CDataType* p_c_;
GridDesc_M_N a_grid_desc_m_n_;
GridDesc_M_N b_grid_desc_m_n_;
GridDesc_M_N c_grid_desc_m_n_;
ElementwiseFunctor functor_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
// TODO
(void)arg;
(void)nrepeat;
return 0;
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
// TODO: properly implement this check
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return pArg != nullptr;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::vector<int>& shape,
const std::vector<int>& stride_a,
const std::vector<int>& stride_b,
const std::vector<int>& stride_c,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
shape,
stride_a,
stride_b,
stride_c,
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwise_2D"
<< "<"
<< BlockSize
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment