Unverified Commit b6b02dab authored by Jiaxing Ding's avatar Jiaxing Ding Committed by GitHub
Browse files

[AMD] fix mfma op interface (#791)


Co-authored-by: default avatarJiaxing Ding <jiaxing.ding@bytedance.com>
parent cda5ea15
......@@ -1321,8 +1321,9 @@ def tvm_mfma(
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mfma(
return call_intrin(
dtype,
_tvm_op.Op.get("tl.tvm_mfma"),
shape,
A_layout,
B_layout,
......@@ -1369,7 +1370,16 @@ def tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
return call_intrin(
dtype,
_tvm_op.Op.get("tl.tvm_mfma_store"),
m,
n,
dst_ptr,
src_ptr,
src_offset,
dst_stride,
)
def tvm_rdna_wmma(
......@@ -1436,8 +1446,9 @@ def tvm_rdna_wmma(
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_rdna_wmma(
return call_intrin(
dtype,
_tvm_op.Op.get("tl.tvm_rdna_wmma"),
shape,
A_layout,
B_layout,
......@@ -1484,7 +1495,16 @@ def tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
return call_intrin(
dtype,
_tvm_op.Op.get("tl.tvm_rdna_wmma_store"),
m,
n,
dst_ptr,
src_ptr,
src_offset,
dst_stride,
)
def ptx_cp_async_barrier(barrier_id):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment