Commit b05a594e authored by rocking

Add reduce sum for denominator of softmax

parent 30348daa
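This commit extends the GEMM + softmax example: after the GEMM writes c_m_n, the pipeline computes a reduce-max, a broadcast subtract-and-exponentiate (Sub_Exp), and, new here, a reduce-sum of the exponentials, which is the softmax denominator. The final broadcast divide is still a TODO. As a reading aid, here is a hypothetical host-side reference for the full computation; the name softmax_ref and the choice of reducing over the M dimension (suggested by the length-N outputs c_n_max and exp_n_sum) are assumptions, not part of the commit.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference, not part of the commit: numerically stable softmax
// over the M dimension of a row-major M x N matrix (one max/sum per column).
std::vector<float> softmax_ref(const std::vector<float>& c, int M, int N)
{
    std::vector<float> out(c.size());
    for(int n = 0; n < N; ++n)
    {
        // Step 1: reduce max (DeviceReduceMaxInstance).
        float max_v = c[n];
        for(int m = 1; m < M; ++m)
            max_v = std::max(max_v, c[static_cast<std::size_t>(m) * N + n]);

        // Steps 2+3: broadcast subtract + exp (Sub_Exp), then reduce sum
        // (DeviceReduceSumInstance, added by this commit).
        float sum_v = 0.0f;
        for(int m = 0; m < M; ++m)
        {
            const float e = std::exp(c[static_cast<std::size_t>(m) * N + n] - max_v);
            out[static_cast<std::size_t>(m) * N + n] = e;
            sum_v += e;
        }

        // Step 4: broadcast divide (still a TODO in the example).
        for(int m = 0; m < M; ++m)
            out[static_cast<std::size_t>(m) * N + n] /= sum_v;
    }
    return out;
}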
@@ -85,27 +85,53 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
         8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
-constexpr int ReduceRank   = 2;
+constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
 constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
+constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
 constexpr ck::NanPropagation NanOpt      = ck::NanPropagation::PROPAGATE_NAN;
 constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
 // constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
 using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
-using InElementwiseOperation =
+using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;
+using ReduceMaxInElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
-using AccElementwiseOperation =
+using ReduceMaxAccElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
+using ReduceSumInElementwiseOperation =
+    typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::InElementwiseOperation;
+using ReduceSumAccElementwiseOperation =
+    typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::AccElementwiseOperation;
 using DeviceReduceMaxInstance =
     ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
                                                         CDataType,
                                                         CDataType,
-                                                        ReduceRank,
+                                                        Rank,
                                                         NumReduceDim,
                                                         ReduceMaxOp,
-                                                        InElementwiseOperation,
-                                                        AccElementwiseOperation,
+                                                        ReduceMaxInElementwiseOperation,
+                                                        ReduceMaxAccElementwiseOperation,
+                                                        PropagateNan,
+                                                        false,
+                                                        256,
+                                                        4,
+                                                        64,
+                                                        1,
+                                                        1,
+                                                        0,
+                                                        1,
+                                                        1>;
+using DeviceReduceSumInstance =
+    ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
+                                                        CDataType,
+                                                        CDataType,
+                                                        Rank,
+                                                        NumReduceDim,
+                                                        ReduceSumOp,
+                                                        ReduceSumInElementwiseOperation,
+                                                        ReduceSumAccElementwiseOperation,
                                                         PropagateNan,
                                                         false,
                                                         256,
@@ -119,12 +145,13 @@ using DeviceReduceMaxInstance =
 struct Sub_Exp
 {
-    __host__ __device__ constexpr void operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
+    __host__ __device__ constexpr void
+    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
     {
         dst = src1 - src2;
         // FIXME - use float16 exponential
         float dst_f32 = static_cast<float>(dst);
         dst           = static_cast<CDataType>(exp(dst_f32));
     }
 };
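For context, Sub_Exp computes the shifted exponential that makes softmax numerically stable: subtracting the max before exponentiating leaves the final ratio unchanged but keeps exp from overflowing,

$$\mathrm{softmax}(x)_m = \frac{e^{x_m - \max_k x_k}}{\sum_j e^{x_j - \max_k x_k}}$$

The reduce-sum added by this commit produces the denominator of that expression.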
@@ -198,22 +225,25 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<int> c_m_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
-                          std::vector<std::size_t>({1}));
-    Tensor<CDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<int> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                        std::vector<std::size_t>({1}));
+    Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                std::vector<std::size_t>({1}));
-    const auto i_inLengths  = ck::to_int_vector(c_m_n.mDesc.GetLengths());
-    const auto i_inStrides  = ck::to_int_vector(c_m_n.mDesc.GetStrides());
-    const auto i_outLengths = ck::to_int_vector(c_m_n_max.mDesc.GetLengths());
-    const auto i_outStrides = ck::to_int_vector(c_m_n_max.mDesc.GetStrides());
+    const auto c_m_n_shape     = ck::to_int_vector(c_m_n.mDesc.GetLengths());
+    const auto c_m_n_stride    = ck::to_int_vector(c_m_n.mDesc.GetStrides());
+    const auto reduce_n_shape  = ck::to_int_vector(c_n_max.mDesc.GetLengths());
+    const auto reduce_n_stride = ck::to_int_vector(c_n_max.mDesc.GetStrides());
-    size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_m_n_max.mDesc.GetElementSize();
+    size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_n_max.mDesc.GetElementSize();
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n.mDesc << std::endl;
-    std::cout << "c_m_n_max: " << c_m_n_max.mDesc << std::endl;
-    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+    std::cout << "c_n_max: " << c_n_max.mDesc << std::endl;
+    std::cout << "exp_m_n: " << exp_m_n.mDesc << std::endl;
+    std::cout << "exp_n_sum: " << exp_n_sum.mDesc << std::endl;
     switch(init_method)
     {
@@ -230,9 +260,10 @@ int main(int argc, char* argv[])
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
-    DeviceMem c_m_n_max_device_buf(sizeof(CDataType) * c_m_n_max.mDesc.GetElementSpace());
-    DeviceMem c_m_n_max_indices_dev(0);
-    DeviceMem d_m_n_device_buf(sizeof(CDataType) * d_m_n.mDesc.GetElementSpace());
+    DeviceMem c_n_max_device_buf(sizeof(CDataType) * c_n_max.mDesc.GetElementSpace());
+    DeviceMem indices_device_buf(0);
+    DeviceMem exp_m_n_device_buf(sizeof(CDataType) * exp_m_n.mDesc.GetElementSpace());
+    DeviceMem exp_n_sum_device_buf(sizeof(CDataType) * exp_n_sum.mDesc.GetElementSpace());
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
@@ -265,23 +296,23 @@ int main(int argc, char* argv[])
     // do reduce max
     auto reduce_max = DeviceReduceMaxInstance{};
-    auto wsSizeInBytes = reduce_max.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
-    DeviceMem ws_dev(wsSizeInBytes);
+    auto reduce_max_workspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_max_workspace_device_buf(reduce_max_workspace_size);
     auto reduce_max_argument_ptr = reduce_max.MakeArgumentPointer(
-        i_inLengths,
-        i_inStrides,
-        i_outLengths,
-        i_outStrides,
+        c_m_n_shape,
+        c_m_n_stride,
+        reduce_n_shape,
+        reduce_n_stride,
         reduceDims,
         1,
         0,
         c_m_n_device_buf.GetDeviceBuffer(),
-        c_m_n_max_device_buf.GetDeviceBuffer(),
-        c_m_n_max_indices_dev.GetDeviceBuffer(),
-        ws_dev.GetDeviceBuffer(),
-        InElementwiseOperation{static_cast<int>(reduce_total_length)},
-        AccElementwiseOperation{static_cast<int>(reduce_total_length)});
+        c_n_max_device_buf.GetDeviceBuffer(),
+        indices_device_buf.GetDeviceBuffer(),
+        reduce_max_workspace_device_buf.GetDeviceBuffer(),
+        ReduceMaxInElementwiseOperation{static_cast<int>(reduce_total_length)},
+        ReduceMaxAccElementwiseOperation{static_cast<int>(reduce_total_length)});
     if(!reduce_max.IsSupportedArgument(reduce_max_argument_ptr.get()))
     {
@@ -292,17 +323,17 @@ int main(int argc, char* argv[])
     auto reduce_max_invoker_ptr = reduce_max.MakeInvokerPointer();
     reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
-    // do broadcast sub
+    // do broadcast sub and exp
     auto broadcastSubExp = DeviceElementwiseInstance{};
     auto broadcastSubExp_argument_ptr =
         broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
-                                            c_m_n_max_device_buf.GetDeviceBuffer(),
-                                            d_m_n_device_buf.GetDeviceBuffer(),
+                                            c_n_max_device_buf.GetDeviceBuffer(),
+                                            exp_m_n_device_buf.GetDeviceBuffer(),
                                             {M, N},
                                             {StrideC, 1},
                                             {0, 1},
                                             {StrideC, 1},
                                             Sub_Exp{});
     if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
     {
@@ -313,7 +344,36 @@ int main(int argc, char* argv[])
     auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
     broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
-    // TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
+    // do reduce sum - denominator of softmax
+    auto reduce_sum = DeviceReduceSumInstance{};
+    auto reduce_sum_workspace_size = reduce_sum.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_sum_workspace_device_buf(reduce_sum_workspace_size);
+    auto reduce_sum_argument_ptr = reduce_sum.MakeArgumentPointer(
+        c_m_n_shape,
+        c_m_n_stride,
+        reduce_n_shape,
+        reduce_n_stride,
+        reduceDims,
+        1,
+        0,
+        exp_m_n_device_buf.GetDeviceBuffer(),
+        exp_n_sum_device_buf.GetDeviceBuffer(),
+        indices_device_buf.GetDeviceBuffer(),
+        reduce_sum_workspace_device_buf.GetDeviceBuffer(),
+        ReduceSumInElementwiseOperation{static_cast<int>(reduce_total_length)},
+        ReduceSumAccElementwiseOperation{static_cast<int>(reduce_total_length)});
+    if(!reduce_sum.IsSupportedArgument(reduce_sum_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters seem not to be supported by the DeviceReduce instance, exiting!");
+    };
+    auto reduce_sum_invoker_ptr = reduce_sum.MakeInvokerPointer();
+    reduce_sum_invoker_ptr->Run(reduce_sum_argument_ptr.get(), nrepeat);
+    // TODO - Need BroadcastDiv
     // TODO = do_verification
     (void)do_verification;
     return 0;
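The remaining BroadcastDiv could reuse the same DeviceElementwiseInstance pattern as Sub_Exp. Below is a minimal sketch, assuming the same three-operand functor contract; the Div name and the reuse of the {0, 1} broadcast strides for exp_n_sum are assumptions, not part of this commit.

// Hypothetical sketch only - not part of this commit.
struct Div
{
    __host__ __device__ constexpr void
    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
    {
        // src1: exp(c - max); src2: broadcast sum of exponentials.
        // FIXME - like Sub_Exp, fall back to float math for float16 types.
        float quotient = static_cast<float>(src1) / static_cast<float>(src2);
        dst            = static_cast<CDataType>(quotient);
    }
};

It would be launched the same way as broadcastSubExp, reading exp_m_n_device_buf and exp_n_sum_device_buf and writing the normalized softmax output.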