Commit 21802fda authored by rocking's avatar rocking
Browse files

[What] Sync input of each host kernel and device kernel

[Why] Prevent error propagation
parent e83b22e0
...@@ -455,6 +455,13 @@ int main(int argc, char* argv[]) ...@@ -455,6 +455,13 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
std::cout << "verification..." << std::endl; std::cout << "verification..." << std::endl;
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
c_n_max_device_buf.FromDevice(c_n_max.mData.data());
exp_m_n_device_buf.FromDevice(exp_m_n.mData.data());
exp_n_sum_device_buf.FromDevice(exp_n_sum.mData.data());
softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data());
const std::vector<int> reduceInvariantDims{1}; const std::vector<int> reduceInvariantDims{1};
Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> host_c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}), Tensor<CDataType> host_c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
...@@ -478,28 +485,22 @@ int main(int argc, char* argv[]) ...@@ -478,28 +485,22 @@ int main(int argc, char* argv[])
host_gemm_invoker.Run(host_gemm_argument); host_gemm_invoker.Run(host_gemm_argument);
host_reduce_max.Run(1, // alpha host_reduce_max.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(host_c_m_n.mData.data()), reinterpret_cast<const HostReduceDataType*>(c_m_n.mData.data()),
0, // beta 0, // beta
reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()), reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
host_indices.mData.data()); host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>( host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
host_exp_m_n, host_c_m_n, host_c_n_max, M, N, Sub_Exp{}); host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
host_reduce_sum.Run(1, // alpha host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(host_exp_m_n.mData.data()), reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
0, // beta 0, // beta
reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()), reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
host_indices.mData.data()); host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>( host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
host_softmax_m_n, host_exp_m_n, host_exp_n_sum, M, N, Div{}); host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
c_n_max_device_buf.FromDevice(c_n_max.mData.data());
exp_m_n_device_buf.FromDevice(exp_m_n.mData.data());
exp_n_sum_device_buf.FromDevice(exp_n_sum.mData.data());
softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data());
bool result = true; bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData)) if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment