Commit 680cfaa7 authored by rocking's avatar rocking
Browse files

Fix the meaning of broadcast dim parameter

parent 5d36f7a2
...@@ -223,7 +223,7 @@ void host_broadcast2D( ...@@ -223,7 +223,7 @@ void host_broadcast2D(
{ {
ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n)); ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
ComputeDataType Cmn = 0; ComputeDataType Cmn = 0;
if constexpr(broadcastDim == 1) if constexpr(broadcastDim == 0)
{ {
ComputeDataType Bn = static_cast<ComputeDataType>(B(n)); ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
functor(Cmn, Amn, Bn); functor(Cmn, Amn, Bn);
...@@ -516,7 +516,7 @@ int main(int argc, char* argv[]) ...@@ -516,7 +516,7 @@ int main(int argc, char* argv[])
Tensor<CDataType>, Tensor<CDataType>,
EltwiseComputeDataType, EltwiseComputeDataType,
Sub_Exp, Sub_Exp,
1>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{}); 0>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
host_reduce_sum.Run(1, // alpha host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()), reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
...@@ -529,7 +529,7 @@ int main(int argc, char* argv[]) ...@@ -529,7 +529,7 @@ int main(int argc, char* argv[])
Tensor<CDataType>, Tensor<CDataType>,
EltwiseComputeDataType, EltwiseComputeDataType,
Div, Div,
1>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{}); 0>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
bool result = true; bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData)) if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment