Commit 9174b783 authored by Shucai Xiao
Browse files

fixed a race condition problem in quant_gemm.

parent 277ae59c
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include <migraphx/cpu/gemm.hpp> #include <migraphx/cpu/gemm.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <fstream>
#include <iomanip>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -565,8 +563,7 @@ struct cpu_quant_gemm ...@@ -565,8 +563,7 @@ struct cpu_quant_gemm
} }
// 2 input arguments // 2 input arguments
int32_t beta = 0; migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
migemm(result, arg_0, arg_1, op.alpha, beta);
return result; return result;
} }
......
...@@ -69,6 +69,11 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg) ...@@ -69,6 +69,11 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
}); });
} }
// Synchronizes with `stream` by forwarding it to hipStreamSynchronize.
// Per the HIP runtime documentation, that call blocks the host until the work
// previously queued on the stream has finished.
// NOTE(review): the hipError_t status returned by hipStreamSynchronize is
// discarded, so a failed synchronization is not reported to the caller —
// confirm whether it should be checked.
void sync_stream(hipStream_t stream)
{
    // Explicitly ignore the returned status.
    (void)hipStreamSynchronize(stream);
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -13,6 +13,9 @@ namespace device { ...@@ -13,6 +13,9 @@ namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg); void pack_a(hipStream_t stream, const argument& result, const argument& arg);
void pack_b(hipStream_t stream, const argument& result, const argument& arg); void pack_b(hipStream_t stream, const argument& result, const argument& arg);
void sync_stream(hipStream_t stream);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <migraphx/gpu/device/pack.hpp> #include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <fstream>
#include <iomanip>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -86,6 +88,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -86,6 +88,7 @@ argument miopen_quant_gemm::compute(context& ctx,
{ {
device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]); device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
} }
device::sync_stream(ctx.get_stream().get());
bool is_3inputs = (arg_num == 6); bool is_3inputs = (arg_num == 6);
int32_t beta = 0; int32_t beta = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment