Commit b05a594e authored by rocking

Add reduce sum for denominator of softmax

parent 30348daa
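Context for this change: the example assembles a numerically stable softmax over the GEMM output in stages: a max-reduction, a broadcast subtract-and-exponentiate (Sub_Exp), and now a sum-reduction for the denominator; the final broadcast division is still a TODO. The target quantity is

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j,
$$

where subtracting the max m keeps every exponential in (0, 1]: the reduce-max supplies m, Sub_Exp produces the numerators, and the reduce-sum added here produces the denominator.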
@@ -85,27 +85,53 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
constexpr bool PropagateNan = (NanOpt != ck::NanPropagation::NOT_PROPAGATE_NAN);
// constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;
using ReduceMaxInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
using ReduceMaxAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
using ReduceSumInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::InElementwiseOperation;
using ReduceSumAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::AccElementwiseOperation;
using DeviceReduceMaxInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType,
CDataType,
Rank,
NumReduceDim,
ReduceMaxOp,
ReduceMaxInElementwiseOperation,
ReduceMaxAccElementwiseOperation,
PropagateNan,
false,
256,
4,
64,
1,
1,
0,
1,
1>;
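// Note: the trailing integer arguments (256, 4, 64, 1, 1, 0, 1, 1) are
// block-level tuning parameters of DeviceReduceBlockWise (block size,
// thread-cluster / per-thread slice sizes and vector access widths); see the
// DeviceReduceBlockWise template parameter list for their exact meanings.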
using DeviceReduceSumInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType,
CDataType,
Rank,
NumReduceDim,
ReduceSumOp,
ReduceSumInElementwiseOperation,
ReduceSumAccElementwiseOperation,
PropagateNan,
false,
256,
@@ -119,7 +145,8 @@ using DeviceReduceMaxInstance =
struct Sub_Exp
{
__host__ __device__ constexpr void
operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
{
dst = src1 - src2;
// FIXME - use float16 exponential
@@ -198,22 +225,25 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1}));
Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1}));
const auto c_m_n_shape = ck::to_int_vector(c_m_n.mDesc.GetLengths());
const auto c_m_n_stride = ck::to_int_vector(c_m_n.mDesc.GetStrides());
const auto reduce_n_shape = ck::to_int_vector(c_n_max.mDesc.GetLengths());
const auto reduce_n_stride = ck::to_int_vector(c_n_max.mDesc.GetStrides());
size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_n_max.mDesc.GetElementSize();
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n.mDesc << std::endl;
std::cout << "c_m_n_max: " << c_m_n_max.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "c_n_max: " << c_n_max.mDesc << std::endl;
std::cout << "exp_m_n: " << exp_m_n.mDesc << std::endl;
std::cout << "exp_n_sum: " << exp_n_sum.mDesc << std::endl;
switch(init_method)
{
@@ -230,9 +260,10 @@ int main(int argc, char* argv[])
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem c_n_max_device_buf(sizeof(CDataType) * c_n_max.mDesc.GetElementSpace());
DeviceMem indices_device_buf(0);
DeviceMem exp_m_n_device_buf(sizeof(CDataType) * exp_m_n.mDesc.GetElementSpace());
DeviceMem exp_n_sum_device_buf(sizeof(CDataType) * exp_n_sum.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
@@ -265,23 +296,23 @@ int main(int argc, char* argv[])
// do reduce max
auto reduce_max = DeviceReduceMaxInstance{};
auto reduce_max_workspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
DeviceMem reduce_max_workspace_device_buf(reduce_max_workspace_size);
auto reduce_max_argument_ptr = reduce_max.MakeArgumentPointer(
c_m_n_shape,
c_m_n_stride,
reduce_n_shape,
reduce_n_stride,
reduceDims,
1, // alpha
0, // beta
c_m_n_device_buf.GetDeviceBuffer(),
c_n_max_device_buf.GetDeviceBuffer(),
indices_device_buf.GetDeviceBuffer(),
reduce_max_workspace_device_buf.GetDeviceBuffer(),
ReduceMaxInElementwiseOperation{static_cast<int>(reduce_total_length)},
ReduceMaxAccElementwiseOperation{static_cast<int>(reduce_total_length)});
if(!reduce_max.IsSupportedArgument(reduce_max_argument_ptr.get()))
{
@@ -292,12 +323,12 @@ int main(int argc, char* argv[])
auto reduce_max_invoker_ptr = reduce_max.MakeInvokerPointer();
reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
// do broadcast sub and exp
auto broadcastSubExp = DeviceElementwiseInstance{};
auto broadcastSubExp_argument_ptr =
broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
c_n_max_device_buf.GetDeviceBuffer(),
exp_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{StrideC, 1},
{0, 1},
@@ -313,7 +344,36 @@ int main(int argc, char* argv[])
auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
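// Note on the broadcast above: the second input (c_n_max) is indexed with
// strides {0, 1} against the {M, N} output shape, so the kernel reuses the
// same per-column max for every row m; Sub_Exp therefore computes
// exp(c_m_n(m, n) - c_n_max(n)) elementwise.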
// do reduce sum - denominator of softmax
auto reduce_sum = DeviceReduceSumInstance{};
auto reduce_sum_workspace_size = reduce_sum.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
DeviceMem reduce_sum_workspace_device_buf(reduce_sum_workspace_size);
auto reduce_sum_argument_ptr = reduce_sum.MakeArgumentPointer(
c_m_n_shape,
c_m_n_stride,
reduce_n_shape,
reduce_n_stride,
reduceDims,
1, // alpha
0, // beta
exp_m_n_device_buf.GetDeviceBuffer(),
exp_n_sum_device_buf.GetDeviceBuffer(),
indices_device_buf.GetDeviceBuffer(),
reduce_sum_workspace_device_buf.GetDeviceBuffer(),
ReduceSumInElementwiseOperation{static_cast<int>(reduce_total_length)},
ReduceSumAccElementwiseOperation{static_cast<int>(reduce_total_length)});
if(!reduce_sum.IsSupportedArgument(reduce_sum_argument_ptr.get()))
{
throw std::runtime_error(
"The runtime parameters seem not to be supported by the DeviceReduce instance, exiting!");
}
auto reduce_sum_invoker_ptr = reduce_sum.MakeInvokerPointer();
reduce_sum_invoker_ptr->Run(reduce_sum_argument_ptr.get(), nrepeat);
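// Illustrative host-side reference for the do_verification TODO below (a
// sketch only, not part of this commit): it assumes c_m_n has been copied
// back from the device and that the reduction runs over M, matching the
// length-N outputs c_n_max and exp_n_sum (needs <cmath>/<algorithm>).
//
// for(int n = 0; n < N; ++n)
// {
//     float max_val = ck::type_convert<float>(c_m_n(0, n));
//     for(int m = 1; m < M; ++m)
//         max_val = std::max(max_val, ck::type_convert<float>(c_m_n(m, n)));
//     float sum = 0.f;
//     for(int m = 0; m < M; ++m)
//         sum += std::exp(ck::type_convert<float>(c_m_n(m, n)) - max_val);
//     // compare max_val with c_n_max(n) and sum with exp_n_sum(n)
// }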
// TODO - Need BroadcastDiv
// TODO - do_verification
(void)do_verification;
return 0;