Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
d6dd2ddf
"...composable_kernel_rocm.git" did not exist on "47e523ef6b8220eac13e578d44115d7518afe731"
Commit
d6dd2ddf
authored
Dec 18, 2025
by
qisan
Browse files
[Bugfix] Fix tvm_mmac not found error
parent
3a6a31c5
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
99 additions
and
0 deletions
+99
-0
src/op/builtin.cc
src/op/builtin.cc
+3
-0
src/op/builtin.h
src/op/builtin.h
+11
-0
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+2
-0
tilelang/language/tir/ir.py
tilelang/language/tir/ir.py
+1
-0
tilelang/language/tir/op.py
tilelang/language/tir/op.py
+82
-0
No files found.
src/op/builtin.cc
View file @
d6dd2ddf
...
@@ -286,6 +286,9 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
...
@@ -286,6 +286,9 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
TIR_DEFINE_TL_BUILTIN
(
tvm_mfma
).
set_num_inputs
(
12
).
set_attr
<
TCallEffectKind
>
(
TIR_DEFINE_TL_BUILTIN
(
tvm_mfma
).
set_num_inputs
(
12
).
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
tvm_mmac
).
set_num_inputs
(
12
).
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
tvm_mfma_store
)
TIR_DEFINE_TL_BUILTIN
(
tvm_mfma_store
)
.
set_num_inputs
(
6
)
.
set_num_inputs
(
6
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
...
...
src/op/builtin.h
View file @
d6dd2ddf
...
@@ -457,6 +457,17 @@ TVM_DLL const Op &loop_break();
...
@@ -457,6 +457,17 @@ TVM_DLL const Op &loop_break();
*/
*/
TVM_DLL
const
Op
&
tvm_mfma
();
TVM_DLL
const
Op
&
tvm_mfma
();
/*!
* \brief tvm intrinsic for amd matrix core mmac instructions.
*
* void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL
const
Op
&
tvm_mmac
();
/*!
/*!
* \brief tvm intrinsic for storing the result of AMD MFMA into a destination
* \brief tvm intrinsic for storing the result of AMD MFMA into a destination
* pointer.
* pointer.
...
...
tilelang/language/ast/ir.py
View file @
d6dd2ddf
...
@@ -1905,6 +1905,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
...
@@ -1905,6 +1905,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh
=
_dtype_forward
(
_tir_op
.
vectorhigh
)
vectorhigh
=
_dtype_forward
(
_tir_op
.
vectorhigh
)
vectorcombine
=
_dtype_forward
(
_tir_op
.
vectorcombine
)
vectorcombine
=
_dtype_forward
(
_tir_op
.
vectorcombine
)
tvm_mfma
=
_dtype_forward
(
_tir_op
.
tvm_mfma
)
tvm_mfma
=
_dtype_forward
(
_tir_op
.
tvm_mfma
)
tvm_mmac
=
_dtype_forward
(
_tir_op
.
tvm_mmac
)
tvm_mfma_store
=
_dtype_forward
(
_tir_op
.
tvm_mfma_store
)
tvm_mfma_store
=
_dtype_forward
(
_tir_op
.
tvm_mfma_store
)
tvm_rdna_wmma
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma
)
tvm_rdna_wmma
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma
)
tvm_rdna_wmma_store
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma_store
)
tvm_rdna_wmma_store
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma_store
)
...
@@ -2165,6 +2166,7 @@ __all__ = [
...
@@ -2165,6 +2166,7 @@ __all__ = [
"vectorhigh"
,
"vectorhigh"
,
"vectorcombine"
,
"vectorcombine"
,
"tvm_mfma"
,
"tvm_mfma"
,
"tvm_mmac"
,
"tvm_mfma_store"
,
"tvm_mfma_store"
,
"tvm_rdna_wmma"
,
"tvm_rdna_wmma"
,
"tvm_rdna_wmma_store"
,
"tvm_rdna_wmma_store"
,
...
...
tilelang/language/tir/ir.py
View file @
d6dd2ddf
...
@@ -312,6 +312,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
...
@@ -312,6 +312,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh
=
_dtype_forward
(
_tir_op
.
vectorhigh
)
vectorhigh
=
_dtype_forward
(
_tir_op
.
vectorhigh
)
vectorcombine
=
_dtype_forward
(
_tir_op
.
vectorcombine
)
vectorcombine
=
_dtype_forward
(
_tir_op
.
vectorcombine
)
tvm_mfma
=
_dtype_forward
(
_tir_op
.
tvm_mfma
)
tvm_mfma
=
_dtype_forward
(
_tir_op
.
tvm_mfma
)
tvm_mmac
=
_dtype_forward
(
_tir_op
.
tvm_mmac
)
tvm_mfma_store
=
_dtype_forward
(
_tir_op
.
tvm_mfma_store
)
tvm_mfma_store
=
_dtype_forward
(
_tir_op
.
tvm_mfma_store
)
tvm_rdna_wmma
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma
)
tvm_rdna_wmma
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma
)
tvm_rdna_wmma_store
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma_store
)
tvm_rdna_wmma_store
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma_store
)
tilelang/language/tir/op.py
View file @
d6dd2ddf
...
@@ -1529,6 +1529,88 @@ def tvm_mfma(
...
@@ -1529,6 +1529,88 @@ def tvm_mfma(
)
)
def
tvm_mmac
(
dtype
,
shape
,
A_layout
,
B_layout
,
A_dtype
,
B_dtype
,
C_dtype
,
multiplicand_a
,
a_index
,
multiplicand_b
,
b_index
,
accumulator
,
c_index
,
):
"""TVM intrinsic for amd matrix core mfma instructions
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
Parameters
----------
dtype : str
The data type of the result.
shape : str
The shape of mma fragment.
A_layout : Literal["row", "col"]
The layout of multiplicand fragment A.
B_layout : Literal["row", "col"]
The layout of multiplicand fragment B.
A_dtype : str
The data type of multiplicand fragment A.
B_dtype : str
The data type of multiplicand fragment B.
C_dtype : str
The data type of accumulator fragment C.
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment A.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
"""
return
call_intrin
(
dtype
,
_tvm_op
.
Op
.
get
(
"tl.tvm_mmac"
),
shape
,
A_layout
,
B_layout
,
A_dtype
,
B_dtype
,
C_dtype
,
multiplicand_a
,
a_index
,
multiplicand_b
,
b_index
,
accumulator
,
c_index
,
)
def
tvm_mfma_store
(
dtype
,
m
,
n
,
dst_ptr
,
src_ptr
,
src_offset
,
dst_stride
):
def
tvm_mfma_store
(
dtype
,
m
,
n
,
dst_ptr
,
src_ptr
,
src_offset
,
dst_stride
):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment