Commit d6dd2ddf authored by qisan's avatar qisan
Browse files

[Bugfix] Fix tvm_mmac not found error

parent 3a6a31c5
......@@ -286,6 +286,9 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
// Defines the tvm_mfma TL builtin: 12 inputs, call effect kind kOpaque.
TIR_DEFINE_TL_BUILTIN(tvm_mfma).set_num_inputs(12).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
// Defines the tvm_mmac TL builtin. The 12 inputs are six StringImm
// descriptors (shape, A/B layout, A/B/C dtype) followed by three
// (Var, index) pairs for multiplicand A, multiplicand B and the accumulator.
// The call effect kind is kOpaque, the same as tvm_mfma.
// NOTE(review): presumably kOpaque keeps passes from reordering or removing
// the call -- confirm against TVM's CallEffectKind documentation.
TIR_DEFINE_TL_BUILTIN(tvm_mmac).set_num_inputs(12).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tvm_mfma_store)
.set_num_inputs(6)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -457,6 +457,17 @@ TVM_DLL const Op &loop_break();
*/
TVM_DLL const Op &tvm_mfma();
/*!
* \brief TVM intrinsic for AMD matrix core mmac instructions.
*
* void tvm_mmac(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op &tvm_mmac();
/*!
* \brief tvm intrinsic for storing the result of AMD MFMA into a destination
* pointer.
......
......@@ -1905,6 +1905,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
# Matrix-core intrinsics (MFMA / MMAC / RDNA WMMA): each name wraps the
# same-named function from _tir_op with _dtype_forward.
tvm_mfma = _dtype_forward(_tir_op.tvm_mfma)
tvm_mmac = _dtype_forward(_tir_op.tvm_mmac)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)
......@@ -2165,6 +2166,7 @@ __all__ = [
"vectorhigh",
"vectorcombine",
"tvm_mfma",
"tvm_mmac",
"tvm_mfma_store",
"tvm_rdna_wmma",
"tvm_rdna_wmma_store",
......
......@@ -312,6 +312,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
# Matrix-core intrinsics (MFMA / MMAC / RDNA WMMA): each name wraps the
# same-named function from _tir_op with _dtype_forward.
tvm_mfma = _dtype_forward(_tir_op.tvm_mfma)
tvm_mmac = _dtype_forward(_tir_op.tvm_mmac)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)
......@@ -1529,6 +1529,88 @@ def tvm_mfma(
)
def tvm_mmac(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
):
    """Build a call to the ``tl.tvm_mmac`` matrix core mmac intrinsic.

    Parameters
    ----------
    dtype : str
        The data type of the result.
    shape : str
        The shape of the mmac fragment.
    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.
    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.
    A_dtype : str
        The data type of multiplicand fragment A.
    B_dtype : str
        The data type of multiplicand fragment B.
    C_dtype : str
        The data type of accumulator fragment C.
    multiplicand_a : Var
        The multiplicand fragment A variable.
    a_index : Expr
        The index of multiplicand fragment A.
    multiplicand_b : Var
        The multiplicand fragment B variable.
    b_index : Expr
        The index of multiplicand fragment B.
    accumulator : Var
        The accumulator fragment C variable.
    c_index : Expr
        The index of accumulator fragment C.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    mmac_op = _tvm_op.Op.get("tl.tvm_mmac")
    # Operands are forwarded positionally, in the same order as the
    # parameters above: descriptors first, then the (variable, index) pairs
    # for A, B and the accumulator.
    operands = (
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
    )
    return call_intrin(dtype, mmac_op, *operands)
def tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment