Commit 70c70d6c authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Add synchronization into MFMA kernels

parent f1f36a61
......@@ -643,7 +643,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
__syncthreads();
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
__syncthreads();
for(int i = 0; i < vectorSize(fragC); ++i)
{
......@@ -696,8 +698,10 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
__syncthreads();
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
__syncthreads();
for(int i = 0; i < vectorSize(fragC); ++i)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment