Commit fe659502 authored by rocking

Add verification of softmax

parent dba65b1c
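
For context, the example this commit verifies computes C = A × B with a device GEMM and then a numerically stable softmax along the reduced dimension (reduceDims{0}, i.e. over m for each column n), in four stages: max-reduce, subtract-and-exponentiate (Sub_Exp), sum-reduce, and broadcast divide (Div). The quantity checked against the new host reference is:

$$
\mathrm{softmax}(C)_{mn} = \frac{e^{\,c_{mn} - \max_{m'} c_{m'n}}}{\sum_{m'} e^{\,c_{m'n} - \max_{m'} c_{m'n}}}
$$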
@@ -5,11 +5,14 @@
 #include <stdlib.h>
 #include <half.hpp>
 #include <math.h>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "host_reduce_util.hpp"
+#include "host_reduction.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
 #include "device_gemm_xdl_c_shuffle.hpp"
@@ -89,9 +92,7 @@ constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
 constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
 constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
-constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
-constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
-// constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
+constexpr bool ReducePropagateNan = false;
 using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
 using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;
 using ReduceMaxInElementwiseOperation =
@@ -112,7 +113,7 @@ using DeviceReduceMaxInstance =
     ReduceMaxOp,
     ReduceMaxInElementwiseOperation,
     ReduceMaxAccElementwiseOperation,
-    PropagateNan,
+    ReducePropagateNan,
     false,
     256,
     4,
@@ -132,7 +133,7 @@ using DeviceReduceSumInstance =
     ReduceSumOp,
     ReduceSumInElementwiseOperation,
     ReduceSumAccElementwiseOperation,
-    PropagateNan,
+    ReducePropagateNan,
     false,
     256,
     4,
@@ -170,9 +171,47 @@ using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
 using DeviceElementwiseDivInstance = ck::tensor_operation::device::
     DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;

-using ReferenceGemmInstance = ck::tensor_operation::host::
+using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;

+using HostReduceMaxInstance = ReductionHost<CDataType,
+                                            CDataType,
+                                            CDataType,
+                                            ReduceMaxId,
+                                            Rank,
+                                            NumReduceDim,
+                                            ReducePropagateNan,
+                                            false>;
+
+using HostReduceSumInstance = ReductionHost<CDataType,
+                                            CDataType,
+                                            CDataType,
+                                            ReduceSumId,
+                                            Rank,
+                                            NumReduceDim,
+                                            ReducePropagateNan,
+                                            false>;
+
+template <typename HostTensorA,
+          typename HostTensorB,
+          typename HostTensorC,
+          typename Functor,
+          int broadcastDim>
+void host_broadcast2D(
+    HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
+{
+    for(int m = 0; m < M; ++m)
+    {
+        for(int n = 0; n < N; ++n)
+        {
+            if constexpr(broadcastDim == 1)
+                functor(C(m, n), A(m, n), B(n));
+            else
+                functor(C(m, n), A(m, n), B(m));
+        }
+    }
+}
+
 int main(int argc, char* argv[])
 {
     bool do_verification = 0;
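
The new host_broadcast2D helper applies a binary functor elementwise while broadcasting a 1-D tensor across one dimension of a 2-D tensor; with broadcastDim == 1 the vector B is indexed by the column n only, which is how the per-column max and sum are applied in the verification below. A minimal standalone sketch of the same pattern, using plain std::vector and float in place of CK's Tensor and CDataType (the names here are illustrative, not CK API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 3;
    std::vector<float> A = {1, 2, 3, 4, 5, 6}; // M x N, row-major
    std::vector<float> B = {4, 5, 6};          // length N, e.g. per-column maxima
    std::vector<float> C(M * N);

    // broadcastDim == 1: B is indexed by n only, so the same B[n] is
    // combined with every row m of column n (here: the Sub_Exp step).
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            C[m * N + n] = std::exp(A[m * N + n] - B[n]);

    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            std::printf("%f ", C[m * N + n]);
        std::printf("\n");
    }
    return 0;
}
```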
@@ -189,7 +228,6 @@ int main(int argc, char* argv[])
     ck::index_t StrideC = 4096;

     const std::vector<int> reduceDims{0};
-    const std::vector<int> reduceInvariantDims{1};

     if(argc == 4)
     {
@@ -237,7 +275,7 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<int> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+    Tensor<CDataType> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
                         std::vector<std::size_t>({1}));
     Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
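
The one-word change above appears to be a type fix: c_n_max receives the MAX reduction's values, which are CDataType, while int is only appropriate for the separate index output; the previous Tensor<int> would have truncated the reduced maxima.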
@@ -370,8 +408,8 @@ int main(int argc, char* argv[])
         reduce_n_shape,
         reduce_n_stride,
         reduceDims,
-        1,
-        0,
+        1, // alpha
+        0, // beta
         exp_m_n_device_buf.GetDeviceBuffer(),
         exp_n_sum_device_buf.GetDeviceBuffer(),
         indices_device_buf.GetDeviceBuffer(),
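
A note on the newly labeled scalars: assuming the conventional BLAS-style reduction contract (an assumption, not taken from CK's documentation), the call computes $\text{out} = \alpha \cdot \mathrm{reduce}(\text{in}) + \beta \cdot \text{out}$, so alpha = 1, beta = 0 simply overwrites the output buffer with the plain reduction result.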
@@ -410,6 +448,66 @@
     broadcastDiv_invoker_ptr->Run(broadcastDiv_argument_ptr.get(), nrepeat);

     // TODO = do_verification
-    (void)do_verification;
+    if(do_verification)
+    {
+        std::cout << "verification..." << std::endl;
+
+        const std::vector<int> reduceInvariantDims{1};
+
+        Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+        Tensor<CDataType> host_c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                       std::vector<std::size_t>({1}));
+        Tensor<int> host_indices(host_c_n_max.mDesc.GetLengths());
+        Tensor<CDataType> host_exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+        Tensor<CDataType> host_exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                         std::vector<std::size_t>({1}));
+        Tensor<CDataType> host_softmax_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+        auto host_gemm          = HostGemmInstance{};
+        auto host_gemm_invoker  = host_gemm.MakeInvoker();
+        auto host_gemm_argument = host_gemm.MakeArgument(
+            a_m_k, b_k_n, host_c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+
+        auto host_reduce_max = HostReduceMaxInstance{
+            host_c_m_n.mDesc, host_c_n_max.mDesc, reduceInvariantDims, reduceDims};
+        auto host_reduce_sum = HostReduceSumInstance{
+            host_exp_m_n.mDesc, host_exp_n_sum.mDesc, reduceInvariantDims, reduceDims};
+
+        host_gemm_invoker.Run(host_gemm_argument);
+
+        host_reduce_max.Run(1, // alpha
+                            reinterpret_cast<const CDataType*>(host_c_m_n.mData.data()),
+                            0, // beta
+                            reinterpret_cast<CDataType*>(host_c_n_max.mData.data()),
+                            host_indices.mData.data());
+
+        host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
+            host_exp_m_n, host_c_m_n, host_c_n_max, M, N, Sub_Exp{});
+
+        host_reduce_sum.Run(1, // alpha
+                            reinterpret_cast<const CDataType*>(host_exp_m_n.mData.data()),
+                            0, // beta
+                            reinterpret_cast<CDataType*>(host_exp_n_sum.mData.data()),
+                            host_indices.mData.data());
+
+        host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
+            host_softmax_m_n, host_exp_m_n, host_exp_n_sum, M, N, Div{});
+
+        c_m_n_device_buf.FromDevice(c_m_n.mData.data());
+        c_n_max_device_buf.FromDevice(c_n_max.mData.data());
+        exp_m_n_device_buf.FromDevice(exp_m_n.mData.data());
+        exp_n_sum_device_buf.FromDevice(exp_n_sum.mData.data());
+        softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data());
+
+        bool result = true;
+        if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
+            std::cout << "[PASS] - c_m_n" << std::endl;
+        if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
+            std::cout << "[PASS] - c_n_max" << std::endl;
+        if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
+            std::cout << "[PASS] - exp_m_n" << std::endl;
+        if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
+            std::cout << "[PASS] - exp_n_sum" << std::endl;
+        if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
+            std::cout << "[PASS] - softmax_m_n" << std::endl;
+    }
+
     return 0;
 }
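
Taken together, the host reference path mirrors the device pipeline step by step: GEMM, per-column max, exp(x - max), per-column sum, divide. A self-contained sketch of the same chain in plain C++ (float and fixed toy sizes instead of CDataType and the CK host instances; a sanity model of the math, not the CK API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 2;
    std::vector<float> a = {1, 2, 3, 4}; // M x K, row-major
    std::vector<float> b = {1, 0, 0, 1}; // K x N, row-major (identity)
    std::vector<float> c(M * N, 0.f), soft(M * N);

    // Step 1: reference GEMM, c = a * b
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[k * N + n];

    for(int n = 0; n < N; ++n)
    {
        // Step 2: max over m (reduceDims{0} -> one scalar per column n)
        float mx = c[n];
        for(int m = 1; m < M; ++m)
            mx = std::max(mx, c[m * N + n]);

        // Steps 3 and 4: exp(c - max), then sum over m
        float sum = 0.f;
        for(int m = 0; m < M; ++m)
            sum += soft[m * N + n] = std::exp(c[m * N + n] - mx);

        // Step 5: divide by the per-column sum
        for(int m = 0; m < M; ++m)
            soft[m * N + n] /= sum;
    }

    // Each column of soft now sums to 1
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            std::printf("%8.5f ", soft[m * N + n]);
        std::printf("\n");
    }
    return 0;
}
```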