Commit 680cfaa7 authored by rocking's avatar rocking
Browse files

Fix the meaning of broadcast dim parameter

parent 5d36f7a2
......@@ -223,7 +223,7 @@ void host_broadcast2D(
{
ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
ComputeDataType Cmn = 0;
if constexpr(broadcastDim == 1)
if constexpr(broadcastDim == 0)
{
ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
functor(Cmn, Amn, Bn);
......@@ -516,7 +516,7 @@ int main(int argc, char* argv[])
Tensor<CDataType>,
EltwiseComputeDataType,
Sub_Exp,
1>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
0>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
......@@ -529,7 +529,7 @@ int main(int argc, char* argv[])
Tensor<CDataType>,
EltwiseComputeDataType,
Div,
1>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
0>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment