Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bbbf4207
Unverified
Commit
bbbf4207
authored
Nov 14, 2025
by
guchaoyang
Committed by
GitHub
Nov 14, 2025
Browse files
Merge branch 'main' into dcu
parents
8f4628e0
5eb30a4f
Changes
286
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1075 additions
and
906 deletions
+1075
-906
src/op/builtin.h
src/op/builtin.h
+64
-6
src/op/copy.cc
src/op/copy.cc
+15
-9
src/op/copy.h
src/op/copy.h
+7
-45
src/op/fill.cc
src/op/fill.cc
+44
-12
src/op/fill.h
src/op/fill.h
+3
-17
src/op/finalize_reducer.cc
src/op/finalize_reducer.cc
+3
-3
src/op/finalize_reducer.h
src/op/finalize_reducer.h
+5
-17
src/op/gemm.cc
src/op/gemm.cc
+341
-306
src/op/gemm.h
src/op/gemm.h
+47
-103
src/op/gemm_py.cc
src/op/gemm_py.cc
+234
-80
src/op/gemm_py.h
src/op/gemm_py.h
+42
-75
src/op/gemm_sp.cc
src/op/gemm_sp.cc
+80
-80
src/op/gemm_sp.h
src/op/gemm_sp.h
+30
-54
src/op/logical.cc
src/op/logical.cc
+3
-1
src/op/math.cc
src/op/math.cc
+28
-0
src/op/operator.cc
src/op/operator.cc
+2
-2
src/op/operator.h
src/op/operator.h
+3
-5
src/op/parallel.cc
src/op/parallel.cc
+28
-2
src/op/parallel.h
src/op/parallel.h
+5
-18
src/op/reduce.cc
src/op/reduce.cc
+91
-71
No files found.
src/op/builtin.h
View file @
bbbf4207
...
@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
...
@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
bool
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
*
a_is_k_major,
bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
b_dtype_abbrv,
*
b_dtype_abbrv,
StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
A_offset, Var
*
A_offset, Var
B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
scale_out,
*
scale_out,
bool scale_in_a, bool scale_in_b);
* bool scale_in_a, bool scale_in_b);
*/
*/
TVM_DLL
const
Op
&
ptx_wgmma_rs
();
TVM_DLL
const
Op
&
ptx_wgmma_rs
();
/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL
const
Op
&
ptx_tcgen05_mma_ss
();
/*!
* \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
*/
TVM_DLL
const
Op
&
ptx_tcgen05_mma_ts
();
/*!
/*!
* \brief tvm intrinsics for initializing tensor memory
* \brief tvm intrinsics for initializing tensor memory
*
*
...
@@ -265,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
...
@@ -265,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
*/
*/
TVM_DLL
const
Op
&
ptx_deallocate_tensor_memory
();
TVM_DLL
const
Op
&
ptx_deallocate_tensor_memory
();
/*!
* \brief tvm intrinsic for ptx tensor core mma instructions on SM70.
*
* void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index, bool saturate);
*/
TVM_DLL
const
Op
&
ptx_mma_sm70
();
/*!
/*!
* \brief tvm intrinsics for ldmatrix
* \brief tvm intrinsics for ldmatrix
*
*
...
@@ -361,6 +382,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
...
@@ -361,6 +382,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
*/
TVM_DLL
const
Op
&
warpgroup_wait
();
TVM_DLL
const
Op
&
warpgroup_wait
();
/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL
const
Op
&
warpgroup_fence_operand
();
/*!
/*!
* \brief Return the canonical lane index for the calling thread.
* \brief Return the canonical lane index for the calling thread.
*
*
...
@@ -494,7 +523,21 @@ TVM_DLL const Op &tl_shuffle_elect();
...
@@ -494,7 +523,21 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in
* This op is used to represent a descriptor initialization operation in
* tilelang.
* tilelang.
*/
*/
TVM_DLL
const
Op
&
initialize_descriptor
();
TVM_DLL
const
Op
&
initialize_wgmma_descriptor
();
/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL
const
Op
&
initialize_tcgen05_descriptor
();
/*!
* \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
*
* This op wraps the device-side arrive used to signal completion of MMA work
* to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
*/
TVM_DLL
const
Op
&
tcgen05_mma_arrive
();
/*!
/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
* \brief tilelang intrinsic for setting the start address of a descriptor
...
@@ -505,6 +548,7 @@ TVM_DLL const Op &initialize_descriptor();
...
@@ -505,6 +548,7 @@ TVM_DLL const Op &initialize_descriptor();
*/
*/
TVM_DLL
const
Op
&
increase_descriptor_offset
();
TVM_DLL
const
Op
&
increase_descriptor_offset
();
/*!
/*!
* \brief tilelang intrinsic for element-wise atomic addition.
* \brief tilelang intrinsic for element-wise atomic addition.
*
*
...
@@ -513,6 +557,20 @@ TVM_DLL const Op &increase_descriptor_offset();
...
@@ -513,6 +557,20 @@ TVM_DLL const Op &increase_descriptor_offset();
*/
*/
TVM_DLL
const
Op
&
atomicadd_elem_op
();
TVM_DLL
const
Op
&
atomicadd_elem_op
();
/*!
* \brief tilelang intrinsic for assert on device.
*
* This op is used to represent an assert on device
*/
TVM_DLL
const
Op
&
device_assert
();
/*!
* \brief tilelang intrinsic for assert on device with additional message.
*
* This op is used to represent an assert on device with additional message.
*/
TVM_DLL
const
Op
&
device_assert_with_msg
();
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
...
...
src/op/copy.cc
View file @
bbbf4207
...
@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
...
@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/
*/
Copy
::
Copy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
Copy
::
Copy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
CopyNode
>
node
=
make_object
<
CopyNode
>
();
ObjectPtr
<
CopyNode
>
node
=
tvm
::
ffi
::
make_object
<
CopyNode
>
();
Array
<
Range
>
rgs
[
2
];
Array
<
Range
>
rgs
[
2
];
Buffer
bf
[
2
];
Buffer
bf
[
2
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
...
@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned CopyNode.
* @return TileOperator A TileOperator owning the cloned CopyNode.
*/
*/
TileOperator
CopyNode
::
Clone
()
const
{
TileOperator
CopyNode
::
Clone
()
const
{
auto
op
=
make_object
<
CopyNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
CopyNode
>
(
*
this
);
if
(
par_op_
.
defined
())
{
if
(
par_op_
.
defined
())
{
op
->
par_op_
=
Downcast
<
ParallelOp
>
(
par_op_
->
Clone
());
op
->
par_op_
=
Downcast
<
ParallelOp
>
(
par_op_
->
Clone
());
}
}
...
@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
...
@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
using
namespace
tvm
::
transform
;
using
namespace
tvm
::
transform
;
PassContext
pass_ctx
=
PassContext
::
Current
();
PassContext
pass_ctx
=
PassContext
::
Current
();
bool
disable_tma_lower
=
bool
disable_tma_lower
=
pass_ctx
->
GetConfig
<
b
ool
>
(
kDisableTMALower
,
false
).
value
();
pass_ctx
->
GetConfig
<
B
ool
>
(
kDisableTMALower
,
Bool
(
false
)
)
.
value
();
auto
copy_inst
=
GetCopyInst
(
target
,
disable_tma_lower
||
disable_tma
,
auto
copy_inst
=
GetCopyInst
(
target
,
disable_tma_lower
||
disable_tma
,
T
.
layout_map
,
T
.
analyzer
,
T
.
buffer_oob
);
T
.
layout_map
,
T
.
analyzer
,
T
.
buffer_oob
);
...
@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
...
@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
using
namespace
tvm
::
transform
;
using
namespace
tvm
::
transform
;
PassContext
pass_ctx
=
PassContext
::
Current
();
PassContext
pass_ctx
=
PassContext
::
Current
();
bool
disable_tma_lower
=
bool
disable_tma_lower
=
pass_ctx
->
GetConfig
<
b
ool
>
(
kDisableTMALower
,
false
).
value
();
pass_ctx
->
GetConfig
<
B
ool
>
(
kDisableTMALower
,
Bool
(
false
)
)
.
value
();
auto
copy_inst
=
GetCopyInst
(
target
,
disable_tma_lower
||
disable_tma
,
auto
copy_inst
=
GetCopyInst
(
target
,
disable_tma_lower
||
disable_tma
,
T
.
layout_map
,
analyzer
);
T
.
layout_map
,
analyzer
);
if
(
copy_inst
==
CopyInst
::
kTMemLoad
||
copy_inst
==
CopyInst
::
kTMemStore
)
{
if
(
copy_inst
==
CopyInst
::
kTMemLoad
||
copy_inst
==
CopyInst
::
kTMemStore
)
{
...
@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
...
@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
}
}
auto
inner_box_dim
=
as_const_int
(
desc
.
smem_box
[
0
]);
auto
inner_box_dim
=
as_const_int
(
desc
.
smem_box
[
0
]);
ICHECK
(
inner_box_dim
!=
nullptr
);
if
(
inner_box_dim
==
nullptr
)
{
LOG
(
WARNING
)
<<
"inner_box_dim "
<<
desc
.
smem_box
[
0
]
<<
" can only be a constant integer for TMA bulk copy, "
"fallback to normal copy"
;
return
LowerNormalCopy
(
T
,
analyzer
);
}
int
instruction_dim
=
*
inner_box_dim
;
int
instruction_dim
=
*
inner_box_dim
;
if
(
desc
.
swizzle
==
static_cast
<
int
>
(
CU_TENSOR_MAP_SWIZZLE_64B
))
{
if
(
desc
.
swizzle
==
static_cast
<
int
>
(
CU_TENSOR_MAP_SWIZZLE_64B
))
{
instruction_dim
=
64
/
src
->
dtype
.
bytes
();
instruction_dim
=
64
/
src
->
dtype
.
bytes
();
...
@@ -1722,7 +1727,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
...
@@ -1722,7 +1727,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* @param vmap Mapping from original buffer variables to actual Buffer objects.
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/
*/
Conv2DIm2ColOp
::
Conv2DIm2ColOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
Conv2DIm2ColOp
::
Conv2DIm2ColOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
Conv2DIm2ColOpNode
>
node
=
make_object
<
Conv2DIm2ColOpNode
>
();
ObjectPtr
<
Conv2DIm2ColOpNode
>
node
=
tvm
::
ffi
::
make_object
<
Conv2DIm2ColOpNode
>
();
node
->
src
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
src
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
nhw_step
=
args
[
2
];
node
->
nhw_step
=
args
[
2
];
...
@@ -1747,7 +1753,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -1747,7 +1753,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
* @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
*/
*/
TileOperator
Conv2DIm2ColOpNode
::
Clone
()
const
{
TileOperator
Conv2DIm2ColOpNode
::
Clone
()
const
{
auto
op
=
make_object
<
Conv2DIm2ColOpNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
Conv2DIm2ColOpNode
>
(
*
this
);
return
Conv2DIm2ColOp
(
op
);
return
Conv2DIm2ColOp
(
op
);
}
}
...
@@ -1973,9 +1979,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
...
@@ -1973,9 +1979,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
CopyNode
::
RegisterReflection
();
CopyNode
::
RegisterReflection
();
Conv2DIm2ColOpNode
::
RegisterReflection
();
Conv2DIm2ColOpNode
::
RegisterReflection
();
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/copy.h
View file @
bbbf4207
...
@@ -101,8 +101,7 @@ public:
...
@@ -101,8 +101,7 @@ public:
};
};
uint8_t
eviction_policy
;
// Policy for cache eviction
uint8_t
eviction_policy
;
// Policy for cache eviction
static
constexpr
const
char
*
_type_key
=
"tl.Copy"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.Copy"
,
CopyNode
,
TileOperatorNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
CopyNode
,
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
...
@@ -114,23 +113,6 @@ public:
...
@@ -114,23 +113,6 @@ public:
.
def_ro
(
"coalesced_width"
,
&
CopyNode
::
coalesced_width
);
.
def_ro
(
"coalesced_width"
,
&
CopyNode
::
coalesced_width
);
}
}
bool
SEqualReduce
(
const
CopyNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
src
,
other
->
src
)
&&
equal
(
dst
,
other
->
dst
)
&&
equal
(
src_range
,
other
->
src_range
)
&&
equal
(
dst_range
,
other
->
dst_range
)
&&
equal
(
coalesced_width
,
other
->
coalesced_width
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
src
);
hash_reduce
(
dst
);
hash_reduce
(
src_range
);
hash_reduce
(
dst_range
);
hash_reduce
(
coalesced_width
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
/*!
/*!
* \brief Lower the copy operator to a TIR statement.
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
* \param T Arguments for lowering.
...
@@ -291,7 +273,7 @@ protected:
...
@@ -291,7 +273,7 @@ protected:
class
Copy
:
public
TileOperator
{
class
Copy
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
Copy
,
TileOperator
,
CopyNode
);
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
Copy
,
TileOperator
,
CopyNode
);
/*!
/*!
* \brief Constructor.
* \brief Constructor.
...
@@ -323,8 +305,8 @@ public:
...
@@ -323,8 +305,8 @@ public:
PrimExpr
nhw_step
;
// Step size in NHW dimensions
PrimExpr
nhw_step
;
// Step size in NHW dimensions
PrimExpr
c_step
;
// Step size in channel dimension
PrimExpr
c_step
;
// Step size in channel dimension
static
constexpr
const
char
*
_type_key
=
"tl.Conv2DIm2Col"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.Conv2DIm2Col"
,
Conv2DIm2ColOpNode
,
TVM_DECLARE_FINAL_OBJECT_INFO
(
Conv2DIm2ColOpNode
,
TileOperatorNode
);
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
...
@@ -338,26 +320,6 @@ public:
...
@@ -338,26 +320,6 @@ public:
.
def_ro
(
"eviction_policy"
,
&
Conv2DIm2ColOpNode
::
eviction_policy
);
.
def_ro
(
"eviction_policy"
,
&
Conv2DIm2ColOpNode
::
eviction_policy
);
}
}
bool
SEqualReduce
(
const
Conv2DIm2ColOpNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
src
,
other
->
src
)
&&
equal
(
dst
,
other
->
dst
)
&&
equal
(
stride
,
other
->
stride
)
&&
equal
(
padding
,
other
->
padding
)
&&
equal
(
dilation
,
other
->
dilation
)
&&
equal
(
kernel
,
other
->
kernel
)
&&
equal
(
eviction_policy
,
other
->
eviction_policy
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
src
);
hash_reduce
(
dst
);
hash_reduce
(
stride
);
hash_reduce
(
padding
);
hash_reduce
(
dilation
);
hash_reduce
(
kernel
);
hash_reduce
(
eviction_policy
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
/*!
/*!
* \brief Lower to TIR statement.
* \brief Lower to TIR statement.
*/
*/
...
@@ -378,8 +340,8 @@ public:
...
@@ -378,8 +340,8 @@ public:
class
Conv2DIm2ColOp
:
public
TileOperator
{
class
Conv2DIm2ColOp
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
Conv2DIm2ColOp
,
TileOperator
,
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
Conv2DIm2ColOp
,
TileOperator
,
Conv2DIm2ColOpNode
);
Conv2DIm2ColOpNode
);
TVM_DLL
Conv2DIm2ColOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
Conv2DIm2ColOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
@@ -387,4 +349,4 @@ public:
...
@@ -387,4 +349,4 @@ public:
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_COPY_H_
#endif // TVM_TL_OP_COPY_H_
\ No newline at end of file
src/op/fill.cc
View file @
bbbf4207
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "builtin.h"
#include "region.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
...
@@ -60,9 +61,32 @@ using namespace tir;
...
@@ -60,9 +61,32 @@ using namespace tir;
* of bounds.
* of bounds.
*/
*/
Fill
::
Fill
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
Fill
::
Fill
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
FillNode
>
node
=
make_object
<
FillNode
>
();
ObjectPtr
<
FillNode
>
node
=
tvm
::
ffi
::
make_object
<
FillNode
>
();
if
(
args
[
0
]
->
IsInstance
<
BufferLoadNode
>
())
{
// Case 1: Region descriptor call (tl.region)
if
(
const
auto
*
call
=
args
[
0
].
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
auto
region
=
RegionOp
(
call
->
args
,
vmap
);
node
->
dst
=
region
->
GetBuffer
();
node
->
region
=
region
->
GetRanges
();
}
else
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
for
(
int
i
=
0
;
i
<
node
->
dst
->
shape
.
size
();
i
++
)
{
node
->
region
.
push_back
(
Range
(
0
,
node
->
dst
->
shape
[
i
]));
}
}
else
{
ICHECK
(
false
)
<<
"Unsupported call op in tl.fill: "
<<
Downcast
<
Op
>
(
call
->
op
)
->
name
;
}
// Case 2: Explicit BufferRegion (legacy path)
}
else
if
(
args
[
0
]
->
IsInstance
<
BufferRegionNode
>
())
{
auto
region
=
Downcast
<
BufferRegion
>
(
args
[
0
]);
node
->
dst
=
region
->
buffer
;
node
->
region
=
region
->
region
;
// Case 3: Vector/scalar region expressed via BufferLoad indices
}
else
if
(
args
[
0
]
->
IsInstance
<
BufferLoadNode
>
())
{
auto
buffer_load
=
Downcast
<
BufferLoad
>
(
args
[
0
]);
auto
buffer_load
=
Downcast
<
BufferLoad
>
(
args
[
0
]);
for
(
const
auto
&
index
:
buffer_load
->
indices
)
{
for
(
const
auto
&
index
:
buffer_load
->
indices
)
{
if
(
const
auto
*
ramp
=
index
.
as
<
RampNode
>
())
{
if
(
const
auto
*
ramp
=
index
.
as
<
RampNode
>
())
{
...
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
}
}
node
->
dst
=
buffer_load
->
buffer
;
node
->
dst
=
buffer_load
->
buffer
;
// Case 4: Access pointer, fill the full buffer
}
else
{
}
else
{
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
for
(
int
i
=
0
;
i
<
node
->
dst
->
shape
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
node
->
dst
->
shape
.
size
();
i
++
)
{
...
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<<
" != "
<<
node
->
dst
->
shape
.
size
();
<<
" != "
<<
node
->
dst
->
shape
.
size
();
for
(
int
i
=
0
;
i
<
node
->
region
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
node
->
region
.
size
();
i
++
)
{
// bound check if region is static
// bound check if region is static
if
(
node
->
region
[
i
]
->
min
.
as
<
IntImm
>
())
{
if
(
const
auto
*
min_imm
=
node
->
region
[
i
]
->
min
.
as
<
IntImm
Node
>
())
{
int64_t
min
=
Downcast
<
IntImm
>
(
node
->
region
[
i
]
->
min
)
->
value
;
int64_t
min
=
min_imm
->
value
;
ICHECK_GE
(
min
,
0
)
<<
"region["
<<
i
<<
"] = "
<<
min
<<
" < 0"
;
ICHECK_GE
(
min
,
0
)
<<
"region["
<<
i
<<
"] = "
<<
min
<<
" < 0"
;
}
}
if
(
node
->
region
[
i
]
->
extent
.
as
<
IntImm
>
())
{
if
(
const
auto
*
extent_imm
=
node
->
region
[
i
]
->
extent
.
as
<
IntImmNode
>
())
{
int64_t
extent
=
Downcast
<
IntImm
>
(
node
->
region
[
i
]
->
extent
)
->
value
;
// Only perform the upper-bound check when the destination shape
ICHECK_LE
(
extent
,
Downcast
<
IntImm
>
(
node
->
dst
->
shape
[
i
])
->
value
)
// extent is also statically known. If the shape is symbolic (e.g., Var),
<<
"region["
<<
i
<<
"] = "
<<
extent
<<
" > "
<<
node
->
dst
->
shape
[
i
];
// skip this static check to avoid invalid downcasts.
if
(
const
auto
*
shape_imm
=
node
->
dst
->
shape
[
i
].
as
<
IntImmNode
>
())
{
ICHECK_LE
(
extent_imm
->
value
,
shape_imm
->
value
)
<<
"region["
<<
i
<<
"] = "
<<
extent_imm
->
value
<<
" > "
<<
node
->
dst
->
shape
[
i
];
}
}
}
}
}
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
...
@@ -117,7 +147,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -117,7 +147,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator that owns the copied FillNode.
* @return TileOperator A TileOperator that owns the copied FillNode.
*/
*/
TileOperator
FillNode
::
Clone
()
const
{
TileOperator
FillNode
::
Clone
()
const
{
auto
op
=
make_object
<
FillNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
FillNode
>
(
*
this
);
return
Fill
(
op
);
return
Fill
(
op
);
}
}
...
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
...
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for
(
int
i
=
0
;
i
<
ndim
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ndim
;
i
++
)
{
Var
var
=
Var
(
std
::
string
{
char
(
'i'
+
i
)},
region
[
i
]
->
extent
->
dtype
);
Var
var
=
Var
(
std
::
string
{
char
(
'i'
+
i
)},
region
[
i
]
->
extent
->
dtype
);
loop_vars
.
push_back
({
region
[
i
],
var
,
IterVarType
::
kDataPar
});
loop_vars
.
push_back
({
region
[
i
],
var
,
IterVarType
::
kDataPar
});
dst_indices
.
push_back
(
var
);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices
.
push_back
(
region
[
i
]
->
min
+
var
);
}
}
Stmt
body
=
BufferStore
(
dst
,
value
,
dst_indices
);
Stmt
body
=
BufferStore
(
dst
,
value
,
dst_indices
);
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
i
--
)
{
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
i
--
)
{
...
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
...
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return
vectorized_thread_loop
;
return
vectorized_thread_loop
;
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unsupported scope "
<<
dst
.
scope
();
LOG
(
FATAL
)
<<
"Unsupported scope "
<<
dst
.
scope
();
return
Stmt
();
}
}
}
}
...
@@ -226,7 +258,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
...
@@ -226,7 +258,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_FFI_STATIC_INIT_BLOCK
({
FillNode
::
RegisterReflection
();
}
);
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
FillNode
::
RegisterReflection
();
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
\ No newline at end of file
src/op/fill.h
View file @
bbbf4207
...
@@ -20,8 +20,7 @@ public:
...
@@ -20,8 +20,7 @@ public:
tir
::
Buffer
dst
;
///< Destination buffer to fill
tir
::
Buffer
dst
;
///< Destination buffer to fill
PrimExpr
value
;
///< Value to fill with
PrimExpr
value
;
///< Value to fill with
Array
<
Range
>
region
;
///< Region to fill within the buffer
Array
<
Range
>
region
;
///< Region to fill within the buffer
static
constexpr
const
char
*
_type_key
=
"tl.Fill"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.Fill"
,
FillNode
,
TileOperatorNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
FillNode
,
TileOperatorNode
);
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
;
...
@@ -35,19 +34,6 @@ public:
...
@@ -35,19 +34,6 @@ public:
.
def_ro
(
"region"
,
&
FillNode
::
region
);
.
def_ro
(
"region"
,
&
FillNode
::
region
);
}
}
bool
SEqualReduce
(
const
FillNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
dst
,
other
->
dst
)
&&
equal
(
value
,
other
->
value
)
&&
equal
(
region
,
other
->
region
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
dst
);
hash_reduce
(
value
);
hash_reduce
(
region
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
TileOperator
Clone
()
const
;
TileOperator
Clone
()
const
;
private:
private:
...
@@ -58,7 +44,7 @@ private:
...
@@ -58,7 +44,7 @@ private:
/// Wrapper class for fill operations
/// Wrapper class for fill operations
class
Fill
:
public
TileOperator
{
class
Fill
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
Fill
,
TileOperator
,
FillNode
);
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
Fill
,
TileOperator
,
FillNode
);
TVM_DLL
Fill
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
Fill
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
@@ -66,4 +52,4 @@ public:
...
@@ -66,4 +52,4 @@ public:
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_FILL_H_
#endif // TVM_TL_OP_FILL_H_
\ No newline at end of file
src/op/finalize_reducer.cc
View file @
bbbf4207
...
@@ -33,7 +33,7 @@ using namespace tir;
...
@@ -33,7 +33,7 @@ using namespace tir;
* Buffer.
* Buffer.
*/
*/
FinalizeReducerOp
::
FinalizeReducerOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
FinalizeReducerOp
::
FinalizeReducerOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
auto
node
=
make_object
<
FinalizeReducerOpNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
FinalizeReducerOpNode
>
();
node
->
reducer
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
reducer
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
op
=
(
ReducerOpType
)
*
as_const_int
(
args
[
1
]);
node
->
op
=
(
ReducerOpType
)
*
as_const_int
(
args
[
1
]);
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
...
@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
...
@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
* @return TileOperator A TileOperator that contains a deep copy of this node.
* @return TileOperator A TileOperator that contains a deep copy of this node.
*/
*/
TileOperator
FinalizeReducerOpNode
::
Clone
()
const
{
TileOperator
FinalizeReducerOpNode
::
Clone
()
const
{
auto
node
=
make_object
<
FinalizeReducerOpNode
>
(
*
this
);
auto
node
=
tvm
::
ffi
::
make_object
<
FinalizeReducerOpNode
>
(
*
this
);
return
TileOperator
(
node
);
return
TileOperator
(
node
);
}
}
...
@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
...
@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_FFI_STATIC_INIT_BLOCK
({
FinalizeReducerOpNode
::
RegisterReflection
();
}
);
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
FinalizeReducerOpNode
::
RegisterReflection
();
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/finalize_reducer.h
View file @
bbbf4207
...
@@ -27,8 +27,8 @@ public:
...
@@ -27,8 +27,8 @@ public:
tir
::
Buffer
reducer
;
tir
::
Buffer
reducer
;
ReducerOpType
op
;
ReducerOpType
op
;
static
constexpr
const
char
*
_type_key
=
"tl.FinalizeReducerOp"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.FinalizeReducerOp"
,
TVM_DECLARE_FINAL_OBJECT_INFO
(
FinalizeReducerOpNode
,
TileOperatorNode
);
FinalizeReducerOpNode
,
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
...
@@ -37,18 +37,6 @@ public:
...
@@ -37,18 +37,6 @@ public:
.
def_ro
(
"op"
,
&
FinalizeReducerOpNode
::
op
);
.
def_ro
(
"op"
,
&
FinalizeReducerOpNode
::
op
);
}
}
bool
SEqualReduce
(
const
FinalizeReducerOpNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
reducer
,
other
->
reducer
)
&&
equal
(
op
,
other
->
op
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
reducer
);
hash_reduce
(
op
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
override
;
InferLevel
level
)
const
override
;
...
@@ -58,8 +46,8 @@ public:
...
@@ -58,8 +46,8 @@ public:
class
FinalizeReducerOp
:
public
TileOperator
{
class
FinalizeReducerOp
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
FinalizeReducerOp
,
TileOperator
,
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
FinalizeReducerOp
,
TileOperator
,
FinalizeReducerOpNode
);
FinalizeReducerOpNode
);
TVM_DLL
FinalizeReducerOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
FinalizeReducerOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
@@ -67,4 +55,4 @@ public:
...
@@ -67,4 +55,4 @@ public:
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
\ No newline at end of file
src/op/gemm.cc
View file @
bbbf4207
...
@@ -12,77 +12,14 @@
...
@@ -12,77 +12,14 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
using
namespace
tir
;
using
namespace
tir
;
struct
TCGEN5MMAMeta
{
int
atom_m
,
atom_n
,
atom_k
;
};
// Return {is_success, meta}
static
inline
std
::
pair
<
bool
,
TCGEN5MMAMeta
>
GetTCGEN5MMAMeta
(
int
M
,
int
N
,
int
K
,
DataType
ab_dtype
,
DataType
c_dtype
)
{
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std
::
vector
<
int
>
ws_valid_atom_ns
=
{
256
,
128
,
64
};
if
((
ab_dtype
.
is_bfloat16
()
||
ab_dtype
.
is_float16
())
&&
(
c_dtype
.
is_float
()
&&
c_dtype
.
bits
()
==
32
))
{
if
(
K
%
16
!=
0
)
FAIL
;
if
(
M
%
128
==
0
)
{
for
(
int
atom_n
=
256
;
atom_n
>=
16
;
atom_n
-=
16
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
128
,
atom_n
,
16
);
FAIL
;
}
else
if
(
M
%
64
==
0
)
{
for
(
int
atom_n
:
ws_valid_atom_ns
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
64
,
atom_n
,
16
);
FAIL
;
}
else
if
(
M
%
32
==
0
)
{
for
(
int
atom_n
:
ws_valid_atom_ns
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
32
,
atom_n
,
16
);
FAIL
;
}
else
{
FAIL
;
}
}
else
if
((
ab_dtype
.
is_float8_e4m3fn
()
||
ab_dtype
.
is_float8_e5m2
())
&&
(
c_dtype
.
is_float
()
&&
c_dtype
.
bits
()
==
32
))
{
if
(
K
%
32
!=
0
)
FAIL
;
if
(
M
%
128
==
0
)
{
for
(
int
atom_n
=
256
;
atom_n
>=
16
;
atom_n
-=
16
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
128
,
atom_n
,
32
);
FAIL
;
}
else
if
(
M
%
64
==
0
)
{
for
(
int
atom_n
:
ws_valid_atom_ns
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
64
,
atom_n
,
32
);
FAIL
;
}
else
if
(
M
%
32
==
0
)
{
for
(
int
atom_n
:
ws_valid_atom_ns
)
if
(
N
%
atom_n
==
0
)
SUCCESS
(
32
,
atom_n
,
32
);
FAIL
;
}
else
{
FAIL
;
}
}
FAIL
;
#undef FAIL
#undef SUCCESS
}
/**
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
* map.
...
@@ -111,42 +48,130 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
...
@@ -111,42 +48,130 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
* fails with an ICHECK (runtime assertion). No other validation is
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
* performed here.
*/
*/
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static
BufferRegion
NormalizeToBufferRegion
(
const
PrimExpr
&
arg
,
const
BufferMap
&
vmap
)
{
// Case 1: Already a BufferRegion
if
(
arg
->
IsInstance
<
BufferRegionNode
>
())
{
return
Downcast
<
BufferRegion
>
(
arg
);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if
(
const
auto
*
load
=
arg
.
as
<
BufferLoadNode
>
())
{
Array
<
Range
>
ranges
;
for
(
const
PrimExpr
&
index
:
load
->
indices
)
{
if
(
const
auto
*
ramp
=
index
.
as
<
RampNode
>
())
{
ICHECK
(
ramp
->
stride
.
as
<
IntImmNode
>
())
<<
"Ramp stride must be IntImm"
;
ICHECK_EQ
(
ramp
->
stride
.
as
<
IntImmNode
>
()
->
value
,
1
)
<<
"Only stride-1 Ramp is supported in GEMM region conversion"
;
ICHECK
(
ramp
->
lanes
.
as
<
IntImmNode
>
())
<<
"Scalable vector lanes not supported in GEMM region conversion"
;
ranges
.
push_back
(
Range
::
FromMinExtent
(
ramp
->
base
,
ramp
->
lanes
));
}
else
{
ranges
.
push_back
(
Range
::
FromMinExtent
(
index
,
1
));
}
}
return
BufferRegion
(
load
->
buffer
,
ranges
);
}
// Case 3: Call nodes
if
(
const
auto
*
call
=
arg
.
as
<
CallNode
>
())
{
// tl.region(...) — reconstruct via RegionOp
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
RegionOp
region
(
call
->
args
,
vmap
);
return
BufferRegion
(
region
->
GetBuffer
(),
region
->
GetRanges
());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
Var
var
=
Downcast
<
Var
>
(
call
->
args
[
1
]);
Buffer
buf
=
vmap
[
var
];
Array
<
Range
>
ranges
;
for
(
PrimExpr
extent
:
buf
->
shape
)
{
ranges
.
push_back
(
Range
(
IntImm
(
extent
->
dtype
,
0
),
extent
));
}
return
BufferRegion
(
buf
,
ranges
);
}
}
LOG
(
FATAL
)
<<
"Unsupported GEMM argument for BufferRegion: "
<<
arg
;
throw
;
// Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static
PrimExpr
MakeAccessPtrFromRegion
(
const
BufferRegion
&
region
,
int
rw_mask
)
{
Buffer
buf
=
region
->
buffer
;
int
ndim
=
static_cast
<
int
>
(
buf
->
shape
.
size
());
ICHECK
(
ndim
>=
2
)
<<
"GEMM expects buffers with at least 2 dims"
;
// Compute row-major strides
std
::
vector
<
PrimExpr
>
strides
(
ndim
);
PrimExpr
one
=
make_const
(
buf
->
shape
[
0
].
dtype
(),
1
);
PrimExpr
cur
=
one
;
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
cur
;
cur
=
cur
*
buf
->
shape
[
i
];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr
offset
=
make_const
(
buf
->
shape
[
0
].
dtype
(),
0
);
for
(
int
i
=
0
;
i
<
ndim
-
2
;
++
i
)
{
offset
=
offset
+
region
->
region
[
i
]
->
min
*
strides
[
i
];
}
// Extent: last two extents product (elements)
PrimExpr
extent
=
region
->
region
[
ndim
-
2
]
->
extent
*
region
->
region
[
ndim
-
1
]
->
extent
;
// ptype and return handle
PrimExpr
ptype
=
tir
::
TypeAnnotation
(
buf
->
dtype
);
Array
<
PrimExpr
>
acc_args
{
ptype
,
buf
->
data
,
offset
,
extent
,
IntImm
(
DataType
::
Int
(
32
),
rw_mask
)};
return
Call
(
DataType
::
Handle
(),
builtin
::
tvm_access_ptr
(),
acc_args
);
}
Gemm
::
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
Gemm
::
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
GemmNode
>
node
=
make_object
<
GemmNode
>
();
ObjectPtr
<
GemmNode
>
node
=
tvm
::
ffi
::
make_object
<
GemmNode
>
();
node
->
Aptr
=
args
[
0
];
node
->
aRegion_
=
NormalizeToBufferRegion
(
args
[
0
],
vmap
);
node
->
Bptr
=
args
[
1
];
node
->
bRegion_
=
NormalizeToBufferRegion
(
args
[
1
],
vmap
);
node
->
Cptr
=
args
[
2
];
node
->
cRegion_
=
NormalizeToBufferRegion
(
args
[
2
],
vmap
);
node
->
A
=
vmap
[
GetVarFromAccessPtr
(
node
->
Aptr
)];
node
->
B
=
vmap
[
GetVarFromAccessPtr
(
node
->
Bptr
)];
node
->
a_
=
node
->
aRegion_
->
buffer
;
node
->
C
=
vmap
[
GetVarFromAccessPtr
(
node
->
Cptr
)];
node
->
b_
=
node
->
bRegion_
->
buffer
;
node
->
trans_A
=
args
[
3
].
as
<
Bool
>
().
value
();
node
->
c_
=
node
->
cRegion_
->
buffer
;
node
->
trans_B
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
transA_
=
args
[
3
].
as
<
Bool
>
().
value
();
node
->
M
=
args
[
5
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
transB_
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
N
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
m_
=
args
[
5
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
K
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
n_
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
policy
=
GemmWarpPolicy
(
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
k_
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
clear_accum
=
args
[
9
].
as
<
PrimExpr
>
().
value
();
node
->
policy_
=
GemmWarpPolicy
(
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
stride_A
=
args
[
10
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
clearAccum_
=
args
[
9
].
as
<
PrimExpr
>
().
value
();
node
->
stride_B
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
strideA_
=
args
[
10
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offset_A
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
strideB_
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offset_B
=
args
[
13
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offsetA_
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offsetB_
=
args
[
13
].
as
<
IntImm
>
().
value
()
->
value
;
if
(
args
.
size
()
>
14
)
{
if
(
args
.
size
()
>
14
)
{
node
->
kPack
=
args
[
14
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
kPack
_
=
args
[
14
].
as
<
IntImm
>
().
value
()
->
value
;
if
(
node
->
kPack
!=
1
&&
node
->
kPack
!=
2
)
{
if
(
node
->
kPack
_
!=
1
&&
node
->
kPack
_
!=
2
)
{
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
}
}
}
}
if
(
args
.
size
()
>
15
)
{
if
(
args
.
size
()
>
15
)
{
node
->
wg
_w
ait
=
args
[
15
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
wg
W
ait
_
=
args
[
15
].
as
<
IntImm
>
().
value
()
->
value
;
}
}
node
->
mbar
p
tr
=
args
[
16
];
node
->
mbar
P
tr
_
=
args
[
16
];
if
(
node
->
mbar
p
tr
.
as
<
CallNode
>
())
{
if
(
node
->
mbar
P
tr
_
.
as
<
CallNode
>
())
{
node
->
mbar
=
vmap
[
GetVarFromAccessPtr
(
node
->
mbar
p
tr
)];
node
->
mbar
_
=
vmap
[
GetVarFromAccessPtr
(
node
->
mbar
P
tr
_
)];
}
else
{
}
else
{
node
->
mbar
=
std
::
nullopt
;
node
->
mbar
_
=
std
::
nullopt
;
}
}
node
->
C_
coords
=
Array
<
PrimExpr
>
(
node
->
c
C
oords
_
=
Array
<
PrimExpr
>
(
{
args
[
17
].
as
<
PrimExpr
>
().
value
(),
args
[
18
].
as
<
PrimExpr
>
().
value
()});
{
args
[
17
].
as
<
PrimExpr
>
().
value
(),
args
[
18
].
as
<
PrimExpr
>
().
value
()});
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
...
@@ -160,46 +185,45 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -160,46 +185,45 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
*/
TileOperator
GemmNode
::
Clone
()
const
{
TileOperator
GemmNode
::
Clone
()
const
{
auto
op
=
make_object
<
GemmNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
GemmNode
>
(
*
this
);
return
Gemm
(
op
);
return
Gemm
(
op
);
}
}
bool
GemmNode
::
A
llowT
CGEN5MMA
(
Target
target
)
const
{
bool
GemmNode
::
a
llowT
cgen5Mma
(
Target
target
)
const
{
return
TargetIsSm100
(
target
)
&&
return
TargetIsSm100
(
target
)
&&
((
A
.
scope
()
==
"shared.dyn"
||
A
.
scope
()
==
"shared"
||
((
a_
.
scope
()
==
"shared.dyn"
||
a_
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.tmem"
)
&&
a_
.
scope
()
==
"shared.tmem"
)
&&
(
B
.
scope
()
==
"shared.dyn"
||
B
.
scope
()
==
"shared"
)
&&
(
b_
.
scope
()
==
"shared.dyn"
||
b_
.
scope
()
==
"shared"
)
&&
C
.
scope
()
==
"shared.tmem"
)
&&
c_
.
scope
()
==
"shared.tmem"
)
&&
GetTCGEN5MMAMeta
(
M
,
N
,
K
,
A
->
dtype
,
C
->
dtype
).
first
;
GetTCGEN5MMAMeta
(
m_
,
n_
,
k_
,
a_
->
dtype
,
c_
->
dtype
).
first
;
}
}
bool
GemmNode
::
A
llowW
GMMA
(
int
block_size
,
Target
target
)
const
{
bool
GemmNode
::
a
llowW
gmma
(
int
block_size
,
Target
target
)
const
{
tvm
::
transform
::
PassContext
ctxt
=
tvm
::
transform
::
PassContext
::
Current
();
tvm
::
transform
::
PassContext
ctxt
=
tvm
::
transform
::
PassContext
::
Current
();
int
warp_size
=
TargetGetWarpSize
(
target
);
int
warp_size
=
TargetGetWarpSize
(
target
);
int
num_warps
=
block_size
/
warp_size
;
int
num_warps
=
block_size
/
warp_size
;
return
!
ctxt
->
GetConfig
(
kDisableWGMMA
,
Optional
<
Bool
>
()).
value_or
(
false
)
&&
return
!
ctxt
->
GetConfig
(
kDisableWGMMA
,
Optional
<
Bool
>
()).
value_or
(
false
)
&&
TargetIsHopper
(
target
)
&&
(
this
->
M
>=
64
)
&&
(
num_warps
%
4
==
0
)
&&
TargetIsHopper
(
target
)
&&
(
this
->
m_
>=
64
)
&&
(
num_warps
%
4
==
0
)
&&
C
heckW
GMMA
();
c
heckW
gmma
();
}
}
GemmInst
GemmNode
::
GetGemmInst
(
int
block_size
,
Target
target
)
const
{
GemmInst
GemmNode
::
getGemmInst
(
int
block_size
,
Target
target
)
const
{
bool
allow_tcgen5mma
=
AllowTCGEN5MMA
(
target
);
if
(
allowTcgen5Mma
(
target
))
{
bool
allow_wgmma
=
AllowWGMMA
(
block_size
,
target
);
if
(
allow_tcgen5mma
)
{
return
GemmInst
::
kTCGEN5MMA
;
return
GemmInst
::
kTCGEN5MMA
;
}
else
if
(
allow
_w
gmma
)
{
}
else
if
(
allow
W
gmma
(
block_size
,
target
)
)
{
return
GemmInst
::
kWGMMA
;
return
GemmInst
::
kWGMMA
;
}
else
if
(
TargetIsCDNA
(
target
))
{
}
else
if
(
TargetIsCDNA
(
target
))
{
return
GemmInst
::
kMFMA
;
return
GemmInst
::
kMFMA
;
}
else
if
(
TargetIsCuda
(
target
))
{
}
else
if
(
TargetIsCuda
(
target
))
{
return
GemmInst
::
kMMA
;
return
GemmInst
::
kMMA
;
}
else
{
}
else
{
ICHECK
(
0
)
<<
"Unsupported target for gemm: "
<<
target
->
str
();
ICHECK
(
0
)
<<
"Unsupported target for gemm: "
<<
target
;
return
GemmInst
::
kMMA
;
}
}
}
}
std
::
pair
<
int
,
int
>
GemmWarpPolicyNode
::
C
omputeWarpPartition
(
std
::
pair
<
int
,
int
>
GemmWarpPolicyNode
::
c
omputeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
Target
target
,
GemmInst
gemm_inst
)
const
{
int
M
,
int
N
,
int
block_size
,
Target
target
,
GemmInst
gemm_inst
)
const
{
int
num_warps
=
block_size
/
TargetGetWarpSize
(
target
);
int
num_warps
=
block_size
/
TargetGetWarpSize
(
target
);
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
...
@@ -208,7 +232,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
...
@@ -208,7 +232,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
int
m_warp
=
1
,
n_warp
=
1
;
int
m_warp
=
1
,
n_warp
=
1
;
constexpr
int
kMPerWarp
=
16
;
// Rows processed by a single warp
constexpr
int
kMPerWarp
=
16
;
// Rows processed by a single warp
constexpr
int
kNPerWarp
=
8
;
// Columns processed by a single warp
int
kNPerWarp
=
8
;
// Columns processed by a single warp
if
(
TargetIsVolta
(
target
))
{
kNPerWarp
=
16
;
}
ICHECK
(
M
%
kMPerWarp
==
0
)
ICHECK
(
M
%
kMPerWarp
==
0
)
<<
"M must be divisible by "
<<
kMPerWarp
<<
", but got "
<<
M
;
<<
"M must be divisible by "
<<
kMPerWarp
<<
", but got "
<<
M
;
ICHECK
(
N
%
kNPerWarp
==
0
)
ICHECK
(
N
%
kNPerWarp
==
0
)
...
@@ -408,51 +435,52 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
...
@@ -408,51 +435,52 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
* @return true if WGMMA is supported for the current buffers, dtypes, and
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
* transpose/shape constraints; false otherwise.
*/
*/
bool
GemmNode
::
C
heckW
GMMA
()
const
{
bool
GemmNode
::
c
heckW
gmma
()
const
{
if
(
B
.
scope
()
!=
"shared.dyn"
&&
B
.
scope
()
!=
"shared"
)
{
if
(
b_
.
scope
()
!=
"shared.dyn"
&&
b_
.
scope
()
!=
"shared"
)
{
return
false
;
return
false
;
}
}
if
(
C
->
dtype
==
DataType
::
Float
(
16
))
{
if
(
c_
->
dtype
==
DataType
::
Float
(
16
))
{
if
(
A
->
dtype
==
DataType
::
Float
(
16
)
&&
B
->
dtype
==
DataType
::
Float
(
16
))
if
(
a_
->
dtype
==
DataType
::
Float
(
16
)
&&
b_
->
dtype
==
DataType
::
Float
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e4m3
())
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e4m3
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e5m2
())
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e4m3
())
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e4m3
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e5m2
())
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
if
(
C
->
dtype
==
DataType
::
Float
(
32
))
{
}
else
if
(
c_
->
dtype
==
DataType
::
Float
(
32
))
{
if
(
A
->
dtype
==
DataType
::
Float
(
16
)
&&
B
->
dtype
==
DataType
::
Float
(
16
))
if
(
a_
->
dtype
==
DataType
::
Float
(
16
)
&&
b_
->
dtype
==
DataType
::
Float
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
==
DataType
::
BFloat
(
16
)
&&
else
if
(
a_
->
dtype
==
DataType
::
BFloat
(
16
)
&&
B
->
dtype
==
DataType
::
BFloat
(
16
))
b_
->
dtype
==
DataType
::
BFloat
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
==
DataType
::
Float
(
32
)
&&
B
->
dtype
==
DataType
::
Float
(
32
))
else
if
(
a_
->
dtype
==
DataType
::
Float
(
32
)
&&
return
(
!
trans_A
)
&&
trans_B
&&
K
%
8
==
0
;
b_
->
dtype
==
DataType
::
Float
(
32
))
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e4m3
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
8
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e4m3
())
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e5m2
())
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e4m3
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e4m3
())
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
if
(
C
->
dtype
==
DataType
::
Int
(
32
))
{
}
else
if
(
c_
->
dtype
==
DataType
::
Int
(
32
))
{
if
(
A
->
dtype
==
DataType
::
Int
(
8
)
&&
B
->
dtype
==
DataType
::
Int
(
8
))
if
(
a_
->
dtype
==
DataType
::
Int
(
8
)
&&
b_
->
dtype
==
DataType
::
Int
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
Int
(
8
)
&&
B
->
dtype
==
DataType
::
UInt
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
Int
(
8
)
&&
b_
->
dtype
==
DataType
::
UInt
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
UInt
(
8
)
&&
B
->
dtype
==
DataType
::
Int
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
UInt
(
8
)
&&
b_
->
dtype
==
DataType
::
Int
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
UInt
(
8
)
&&
B
->
dtype
==
DataType
::
UInt
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
UInt
(
8
)
&&
b_
->
dtype
==
DataType
::
UInt
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
{
}
else
{
...
@@ -476,8 +504,8 @@ bool GemmNode::CheckWGMMA() const {
...
@@ -476,8 +504,8 @@ bool GemmNode::CheckWGMMA() const {
*/
*/
static
int
GetArchInt
(
Target
target
)
{
static
int
GetArchInt
(
Target
target
)
{
int
arch_int
=
0
;
int
arch_int
=
0
;
auto
s
=
target
->
GetAttr
<
String
>
(
"arch"
);
auto
s
=
target
->
GetAttr
<
tvm
::
ffi
::
String
>
(
"arch"
);
ICHECK
(
s
.
defined
());
ICHECK
(
s
.
has_value
());
std
::
string
arch
=
s
.
value
();
std
::
string
arch
=
s
.
value
();
if
(
arch
.
rfind
(
"sm_"
,
0
)
==
0
)
{
if
(
arch
.
rfind
(
"sm_"
,
0
)
==
0
)
{
arch_int
=
std
::
stoi
(
arch
.
substr
(
3
));
arch_int
=
std
::
stoi
(
arch
.
substr
(
3
));
...
@@ -502,56 +530,61 @@ static int GetArchInt(Target target) {
...
@@ -502,56 +530,61 @@ static int GetArchInt(Target target) {
*/
*/
Stmt
GemmNode
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
Stmt
GemmNode
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
GemmInst
gemm_inst
=
G
etGemmInst
(
block_size
,
T
.
target
);
GemmInst
gemm_inst
=
g
etGemmInst
(
block_size
,
T
.
target
);
auto
[
warp_m
,
warp_n
]
=
auto
[
warp_m
,
warp_n
]
=
policy
->
ComputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
gemm_inst
);
policy_
->
computeWarpPartition
(
m_
,
n_
,
block_size
,
T
.
target
,
gemm_inst
);
// Build access pointers from regions locally
PrimExpr
Aptr
=
MakeAccessPtrFromRegion
(
aRegion_
,
/*r*/
1
);
PrimExpr
Bptr
=
MakeAccessPtrFromRegion
(
bRegion_
,
/*r*/
1
);
PrimExpr
Cptr
=
MakeAccessPtrFromRegion
(
cRegion_
,
/*rw*/
3
);
std
::
stringstream
ss
;
std
::
stringstream
ss
;
std
::
string
op_name
;
std
::
string
op_name
;
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
auto
[
can_use_tcgen5mma
,
meta
]
=
auto
[
can_use_tcgen5mma
,
meta
]
=
GetTCGEN5MMAMeta
(
M
,
N
,
K
,
A
->
dtype
,
C
->
dtype
);
GetTCGEN5MMAMeta
(
m_
,
n_
,
k_
,
a_
->
dtype
,
c_
->
dtype
);
ICHECK
(
can_use_tcgen5mma
);
ICHECK
(
can_use_tcgen5mma
);
ICHECK
(
B
.
scope
()
==
"shared.dyn"
||
B
.
scope
()
==
"shared"
);
ICHECK
(
b_
.
scope
()
==
"shared.dyn"
||
b_
.
scope
()
==
"shared"
);
ICHECK
(
C
.
scope
()
==
"shared.tmem"
);
ICHECK
(
c_
.
scope
()
==
"shared.tmem"
);
ICHECK
(
mbar
.
has_value
())
<<
"mbar must be provided for TCGEN5MMA"
;
ICHECK
(
mbar
_
.
has_value
())
<<
"mbar must be provided for TCGEN5MMA"
;
if
(
A
.
scope
()
==
"shared.tmem"
)
{
if
(
a_
.
scope
()
==
"shared.tmem"
)
{
op_name
=
"tl::tcgen5mma_gemm_ts"
;
op_name
=
"tl::tcgen5mma_gemm_ts"
;
}
else
if
(
A
.
scope
()
==
"shared.dyn"
||
A
.
scope
()
==
"shared"
)
{
}
else
if
(
a_
.
scope
()
==
"shared.dyn"
||
a_
.
scope
()
==
"shared"
)
{
op_name
=
"tl::tcgen5mma_gemm_ss"
;
op_name
=
"tl::tcgen5mma_gemm_ss"
;
}
else
{
}
else
{
ICHECK
(
0
)
ICHECK
(
0
)
<<
"Unsupported A scope for TCGEN5MMA: "
<<
"Unsupported A scope for TCGEN5MMA: "
<<
A
.
scope
();
// If this is triggered, it means Tilelang has bugs.
<<
a_
.
scope
();
// If this is triggered, it means Tilelang has bugs.
}
}
ICHECK
(
wg
_w
ait
==
-
1
)
ICHECK
(
wg
W
ait
_
==
-
1
)
<<
"Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
<<
"Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
"use "
"use "
"wg_wait = -1 and manually synchronize with mbarrier."
;
"wg_wait = -1 and manually synchronize with mbarrier."
;
std
::
string
accum_dtype
=
""
;
std
::
string
accum_dtype
=
""
;
if
(
C
->
dtype
.
is_float
())
{
if
(
c_
->
dtype
.
is_float
())
{
if
(
C
->
dtype
.
bits
()
==
32
)
{
if
(
c_
->
dtype
.
bits
()
==
32
)
{
accum_dtype
=
"float"
;
accum_dtype
=
"float"
;
}
}
}
}
ICHECK
(
!
accum_dtype
.
empty
())
ICHECK
(
!
accum_dtype
.
empty
())
<<
"Unsupported C dtype for TCGEN5MMA: "
<<
C
->
dtype
;
<<
"Unsupported C dtype for TCGEN5MMA: "
<<
c_
->
dtype
;
ss
<<
op_name
<<
"<"
<<
M
<<
", "
<<
N
<<
", "
<<
K
<<
", "
;
ss
<<
op_name
<<
"<"
<<
m_
<<
", "
<<
n_
<<
", "
<<
k_
<<
", "
;
ss
<<
meta
.
atom_m
<<
", "
<<
meta
.
atom_n
<<
", "
<<
meta
.
atom_k
<<
", "
;
ss
<<
meta
.
atom_m
<<
", "
<<
meta
.
atom_n
<<
", "
<<
meta
.
atom_k
<<
", "
;
ss
<<
trans
_
A
<<
", "
<<
trans
_
B
<<
", "
;
ss
<<
transA
_
<<
", "
<<
transB
_
<<
", "
;
ss
<<
accum_dtype
;
ss
<<
accum_dtype
;
ss
<<
">"
;
ss
<<
">"
;
auto
C_buffer
=
T
.
buffer_remap
.
count
(
C
)
?
T
.
buffer_remap
[
C
]
:
C
;
auto
C_buffer
=
T
.
buffer_remap
.
count
(
c_
)
?
T
.
buffer_remap
[
c_
]
:
c_
;
Array
<
PrimExpr
>
new_args
;
Array
<
PrimExpr
>
new_args
;
new_args
.
push_back
(
StringImm
(
ss
.
str
()));
new_args
.
push_back
(
StringImm
(
ss
.
str
()));
new_args
.
push_back
(
Aptr
);
new_args
.
push_back
(
Aptr
);
new_args
.
push_back
(
Bptr
);
new_args
.
push_back
(
Bptr
);
new_args
.
push_back
(
BufferLoad
(
C_buffer
,
C_
coords
));
new_args
.
push_back
(
BufferLoad
(
C_buffer
,
c
C
oords
_
));
new_args
.
push_back
(
mbar
p
tr
);
new_args
.
push_back
(
mbar
P
tr
_
);
new_args
.
push_back
(
clear
_a
ccum
);
new_args
.
push_back
(
clear
A
ccum
_
);
auto
new_call
=
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
new_args
);
auto
new_call
=
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
new_args
);
// Since TCGEN5MMA atoms provided by CUTLASS always have an internal
// Since TCGEN5MMA atoms provided by CUTLASS always have an internal
...
@@ -576,47 +609,49 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
...
@@ -576,47 +609,49 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}
}
}
if
(
A
.
scope
()
==
"local.fragment"
)
{
if
(
a_
.
scope
()
==
"local.fragment"
)
{
ICHECK
(
B
.
scope
()
!=
"local.fragment"
);
ICHECK
(
b_
.
scope
()
!=
"local.fragment"
);
ICHECK
(
!
transA_
)
<<
"gemm_rs requires the A operand to be in non-transposed layout."
;
op_name
=
"tl::gemm_rs"
;
op_name
=
"tl::gemm_rs"
;
}
else
if
(
B
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
b_
.
scope
()
==
"local.fragment"
)
{
op_name
=
"tl::gemm_sr"
;
op_name
=
"tl::gemm_sr"
;
}
else
{
}
else
{
op_name
=
"tl::gemm_ss"
;
op_name
=
"tl::gemm_ss"
;
}
}
ICHECK
(
C
.
scope
()
==
"local.fragment"
);
ICHECK
(
c_
.
scope
()
==
"local.fragment"
);
ss
<<
op_name
<<
"<"
<<
M
<<
", "
<<
N
<<
", "
<<
K
<<
", "
;
ss
<<
op_name
<<
"<"
<<
m_
<<
", "
<<
n_
<<
", "
<<
k_
<<
", "
;
ss
<<
warp_m
<<
", "
<<
warp_n
<<
", "
;
ss
<<
warp_m
<<
", "
<<
warp_n
<<
", "
;
ss
<<
trans
_
A
<<
", "
<<
trans
_
B
;
ss
<<
transA
_
<<
", "
<<
transB
_
;
auto
clear_accum_bool
=
clear
_a
ccum
.
as
<
Bool
>
();
auto
clear_accum_bool
=
clear
A
ccum
_
.
as
<
Bool
>
();
ICHECK
(
clear_accum_bool
.
has_value
())
ICHECK
(
clear_accum_bool
.
has_value
())
<<
"clear_accum must be a constant Bool type, got "
<<
clear
_a
ccum
;
<<
"clear_accum must be a constant Bool type, got "
<<
clear
A
ccum
_
;
ss
<<
", "
<<
bool
(
clear_accum_bool
.
value
());
ss
<<
", "
<<
bool
(
clear_accum_bool
.
value
());
if
(
TargetIsCuda
(
T
.
target
)
&&
(
GetArchInt
(
T
.
target
)
>=
75
))
{
if
(
TargetIsCuda
(
T
.
target
)
&&
(
GetArchInt
(
T
.
target
)
>=
75
))
{
ss
<<
", "
<<
stride
_
A
<<
", "
<<
stride
_
B
;
ss
<<
", "
<<
strideA
_
<<
", "
<<
strideB
_
;
ss
<<
", "
<<
offset
_
A
<<
", "
<<
offset
_
B
;
ss
<<
", "
<<
offsetA
_
<<
", "
<<
offsetB
_
;
}
}
if
(
TargetIsCDNA
(
T
.
target
))
{
if
(
TargetIsCDNA
(
T
.
target
))
{
// for cdna gemm, we need to specify kPack
// for cdna gemm, we need to specify kPack
ss
<<
", "
<<
kPack
;
ss
<<
", "
<<
kPack
_
;
}
else
if
(
TargetIsHopper
(
T
.
target
))
{
}
else
if
(
TargetIsHopper
(
T
.
target
))
{
ss
<<
", "
<<
(
gemm_inst
==
GemmInst
::
kWGMMA
?
"true"
:
"false"
);
ss
<<
", "
<<
(
gemm_inst
==
GemmInst
::
kWGMMA
?
"true"
:
"false"
);
}
}
// Emit wg_wait if necessary
// Emit wg_wait if necessary
if
(
TargetIsHopper
(
T
.
target
))
{
if
(
TargetIsHopper
(
T
.
target
))
{
if
(
wg
_w
ait
!=
0
)
{
if
(
wg
W
ait
_
!=
0
)
{
ss
<<
", "
<<
wg
_w
ait
;
ss
<<
", "
<<
wg
W
ait
_
;
}
}
}
else
if
(
TargetIsSm100
(
T
.
target
))
{
}
else
if
(
TargetIsSm100
(
T
.
target
))
{
// NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
// NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
// but all threads need to wait, so we emit another statement for cases
// but all threads need to wait, so we emit another statement for cases
// where wg_wait == 0.
// where wg_wait == 0.
ICHECK
(
wg
_w
ait
==
0
||
wg
_w
ait
==
-
1
)
ICHECK
(
wg
W
ait
_
==
0
||
wg
W
ait
_
==
-
1
)
<<
"wg_wait must be 0 or -1 for Sm100"
;
<<
"wg_wait must be 0 or -1 for Sm100"
;
}
else
{
}
else
{
ICHECK
(
wg
_w
ait
==
0
)
ICHECK
(
wg
W
ait
_
==
0
)
<<
"wg_wait must be 0 for non-Hopper and non-Sm100 targets"
;
<<
"wg_wait must be 0 for non-Hopper and non-Sm100 targets"
;
}
}
ss
<<
">"
;
ss
<<
">"
;
...
@@ -652,151 +687,152 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
...
@@ -652,151 +687,152 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
LayoutMap
results
;
LayoutMap
results
;
auto
thread_range
=
T
.
thread_bounds
;
auto
thread_range
=
T
.
thread_bounds
;
auto
block_size
=
*
as_const_int
(
thread_range
->
extent
);
auto
block_size
=
*
as_const_int
(
thread_range
->
extent
);
GemmInst
gemm_inst
=
G
etGemmInst
(
block_size
,
T
.
target
);
GemmInst
gemm_inst
=
g
etGemmInst
(
block_size
,
T
.
target
);
auto
[
warp_m
,
warp_n
]
=
auto
[
warp_m
,
warp_n
]
=
policy
->
C
omputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
gemm_inst
);
policy
_
->
c
omputeWarpPartition
(
m_
,
n_
,
block_size
,
T
.
target
,
gemm_inst
);
if
(
TargetIsVolta
(
T
.
target
))
{
if
(
TargetIsVolta
(
T
.
target
))
{
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
ICHECK
(
c_
.
scope
()
==
"local.fragment"
)
<<
"Volta gemm only supports C in local.fragment scope, got "
<<
"Volta gemm only supports C in local.fragment scope, got "
<<
C
.
scope
();
<<
c_
.
scope
();
auto
fragment
=
auto
fragment
=
makeGemmVoltaFragmentC
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
makeGemmVoltaFragmentC
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
results
.
Set
(
A
,
makeGemmVoltaABLayout
(
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]),
results
.
Set
(
a_
,
makeGemmVoltaABLayout
(
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]),
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]),
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]),
true
,
!
trans_A
));
true
,
!
transA_
));
}
else
if
(
A
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
a_
.
scope
()
==
"local.fragment"
)
{
ICHECK
(
trans_A
==
false
);
ICHECK
(
transA_
==
false
);
auto
fragment
=
makeGemmVoltaFragmentA
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
);
auto
fragment
=
results
.
Set
(
A
,
fragment
->
BindThreadRange
(
thread_range
));
makeGemmVoltaFragmentA
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
);
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
ICHECK
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
);
ICHECK
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
);
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
results
.
Set
(
B
,
makeGemmVoltaABLayout
(
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]),
results
.
Set
(
b_
,
makeGemmVoltaABLayout
(
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]),
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]),
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]),
false
,
trans
_
B
));
false
,
transB
_
));
}
else
if
(
TargetIsAmpere
(
T
.
target
)
||
TargetIsTuring
(
T
.
target
)
||
}
else
if
(
TargetIsAmpere
(
T
.
target
)
||
TargetIsTuring
(
T
.
target
)
||
TargetIsSM120
(
T
.
target
)
||
TargetIsSM120
(
T
.
target
)
||
(
TargetIsSm100
(
T
.
target
)
&&
gemm_inst
==
GemmInst
::
kMMA
))
{
(
TargetIsSm100
(
T
.
target
)
&&
gemm_inst
==
GemmInst
::
kMMA
))
{
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
ICHECK
(
c_
.
scope
()
==
"local.fragment"
)
<<
"MMA only supports C in local.fragment scope, got "
<<
C
.
scope
();
<<
"MMA only supports C in local.fragment scope, got "
<<
c_
.
scope
();
auto
fragment
=
auto
fragment
=
makeGemmFragmentC
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
makeGemmFragmentC
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]);
results
.
Set
(
A
,
results
.
Set
(
a_
,
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
A
->
dtype
.
bits
(),
!
trans
_
A
));
a_
->
dtype
.
bits
(),
!
transA
_
));
}
else
if
(
A
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
a_
.
scope
()
==
"local.fragment"
)
{
auto
fragment
=
makeGemmFragmentA
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
auto
fragment
=
makeGemmFragmentA
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
A
->
dtype
.
bits
(),
trans
_
A
);
a_
->
dtype
.
bits
(),
transA
_
);
results
.
Set
(
A
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
if
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
)
{
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]);
results
.
Set
(
B
,
results
.
Set
(
b_
,
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
B
->
dtype
.
bits
(),
trans
_
B
));
b_
->
dtype
.
bits
(),
transB
_
));
}
else
if
(
B
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
b_
.
scope
()
==
"local.fragment"
)
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentB
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
trans
_
B
);
makeGemmFragmentB
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
transB
_
);
results
.
Set
(
B
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
b_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
}
else
if
(
TargetIsHopper
(
T
.
target
))
{
}
else
if
(
TargetIsHopper
(
T
.
target
))
{
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
ICHECK
(
c_
.
scope
()
==
"local.fragment"
)
<<
(
gemm_inst
==
GemmInst
::
kWGMMA
?
"WGMMA "
:
"MMA "
)
<<
(
gemm_inst
==
GemmInst
::
kWGMMA
?
"WGMMA "
:
"MMA "
)
<<
"only supports C in local.fragment scope, got "
<<
C
.
scope
();
<<
"only supports C in local.fragment scope, got "
<<
c_
.
scope
();
auto
fragment
=
auto
fragment
=
gemm_inst
==
GemmInst
::
kWGMMA
gemm_inst
==
GemmInst
::
kWGMMA
?
makeGemmFragmentCHopper
(
m_
,
n_
,
m_
/
warp_m
,
?
makeGemmFragmentCHopper
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
n_
/
warp_n
,
c_
->
dtype
.
bits
())
C
->
dtype
.
bits
())
:
makeGemmFragmentC
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
:
makeGemmFragmentC
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]);
const
int64_t
continuity
=
const
int64_t
continuity
=
trans
_
A
?
4
*
mat_continuous
/
warp_m
:
mat_continuous
;
transA
_
?
4
*
mat_continuous
/
warp_m
:
mat_continuous
;
auto
ABLayout
=
auto
ABLayout
=
gemm_inst
==
GemmInst
::
kWGMMA
gemm_inst
==
GemmInst
::
kWGMMA
?
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
?
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
A
->
dtype
.
bits
(),
!
trans
_
A
)
a_
->
dtype
.
bits
(),
!
transA
_
)
:
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
:
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
A
->
dtype
.
bits
(),
!
trans
_
A
);
a_
->
dtype
.
bits
(),
!
transA
_
);
results
.
Set
(
A
,
ABLayout
);
results
.
Set
(
a_
,
ABLayout
);
}
else
{
}
else
{
auto
fragment
=
makeGemmFragmentA
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
auto
fragment
=
makeGemmFragmentA
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
A
->
dtype
.
bits
(),
trans
_
A
);
a_
->
dtype
.
bits
(),
transA
_
);
results
.
Set
(
A
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
}
if
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
)
{
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]);
const
int64_t
continuity
=
const
int64_t
continuity
=
trans
_
B
?
mat_continuous
:
mat_continuous
/
warp_n
;
transB
_
?
mat_continuous
:
mat_continuous
/
warp_n
;
auto
ABLayout
=
auto
ABLayout
=
gemm_inst
==
GemmInst
::
kWGMMA
gemm_inst
==
GemmInst
::
kWGMMA
?
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
?
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
B
->
dtype
.
bits
(),
trans
_
B
)
b_
->
dtype
.
bits
(),
transB
_
)
:
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
:
makeGemmABLayout
(
mat_stride
,
mat_continuous
,
mat_continuous
,
B
->
dtype
.
bits
(),
trans
_
B
);
b_
->
dtype
.
bits
(),
transB
_
);
results
.
Set
(
B
,
ABLayout
);
results
.
Set
(
b_
,
ABLayout
);
}
else
{
}
else
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentB
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
trans
_
B
);
makeGemmFragmentB
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
transB
_
);
results
.
Set
(
B
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
b_
,
fragment
->
BindThreadRange
(
thread_range
));
}
}
}
else
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
}
else
if
(
gemm_inst
==
GemmInst
::
kTCGEN5MMA
)
{
ICHECK
(
C
.
scope
()
==
"shared.tmem"
)
ICHECK
(
c_
.
scope
()
==
"shared.tmem"
)
<<
"TCGEN5MMA only supports C in shared.tmem scope, got "
<<
C
.
scope
();
<<
"TCGEN5MMA only supports C in shared.tmem scope, got "
<<
c_
.
scope
();
ICHECK
(
A
.
scope
()
==
"shared.dyn"
||
A
.
scope
()
==
"shared"
)
ICHECK
(
a_
.
scope
()
==
"shared.dyn"
||
a_
.
scope
()
==
"shared"
)
<<
"Current TCGEN5MMA only supports A in shared.dyn scope"
;
<<
"Current TCGEN5MMA only supports A in shared.dyn scope"
;
auto
[
can_use_tcgen5mma
,
meta
]
=
auto
[
can_use_tcgen5mma
,
meta
]
=
GetTCGEN5MMAMeta
(
M
,
N
,
K
,
A
->
dtype
,
C
->
dtype
);
GetTCGEN5MMAMeta
(
m_
,
n_
,
k_
,
a_
->
dtype
,
c_
->
dtype
);
ICHECK
(
can_use_tcgen5mma
);
ICHECK
(
can_use_tcgen5mma
);
{
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]);
results
.
Set
(
A
,
makeGemmABLayoutSm100
(
mat_stride
,
mat_continuous
,
results
.
Set
(
a_
,
makeGemmABLayoutSm100
(
mat_stride
,
mat_continuous
,
mat_continuous
,
A
->
dtype
.
bits
(),
mat_continuous
,
a_
->
dtype
.
bits
(),
trans
_
A
?
1
:
2
));
transA
_
?
1
:
2
));
}
}
{
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]);
const
int64_t
continuity
=
mat_continuous
;
const
int64_t
continuity
=
mat_continuous
;
results
.
Set
(
B
,
results
.
Set
(
b_
,
makeGemmABLayoutSm100
(
mat_stride
,
mat_continuous
,
continuity
,
makeGemmABLayoutSm100
(
mat_stride
,
mat_continuous
,
continuity
,
B
->
dtype
.
bits
(),
trans
_
B
?
2
:
1
));
b_
->
dtype
.
bits
(),
transB
_
?
2
:
1
));
}
}
{
{
Layout
res
;
Layout
res
;
IterVar
i
=
make_itervar
(
"i"
,
M
);
IterVar
i
=
make_itervar
(
"i"
,
m_
);
IterVar
j
=
make_itervar
(
"j"
,
N
);
IterVar
j
=
make_itervar
(
"j"
,
n_
);
ICHECK
(
M
%
meta
.
atom_m
==
0
);
ICHECK
(
m_
%
meta
.
atom_m
==
0
);
PrimExpr
atom_idx
=
FloorDiv
(
i
,
meta
.
atom_m
)
+
PrimExpr
atom_idx
=
FloorDiv
(
i
,
meta
.
atom_m
)
+
FloorDiv
(
j
,
meta
.
atom_n
)
*
(
M
/
meta
.
atom_m
);
FloorDiv
(
j
,
meta
.
atom_n
)
*
(
m_
/
meta
.
atom_m
);
PrimExpr
ai
=
FloorMod
(
i
,
meta
.
atom_m
);
// "ai" means "atom_i"
PrimExpr
ai
=
FloorMod
(
i
,
meta
.
atom_m
);
// "ai" means "atom_i"
PrimExpr
aj
=
FloorMod
(
j
,
meta
.
atom_n
);
PrimExpr
aj
=
FloorMod
(
j
,
meta
.
atom_n
);
if
(
meta
.
atom_m
==
128
)
{
if
(
meta
.
atom_m
==
128
)
{
...
@@ -822,46 +858,46 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
...
@@ -822,46 +858,46 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
results
.
Set
(
C
,
res
);
results
.
Set
(
c_
,
res
);
}
}
}
else
if
(
TargetIsCDNA
(
T
.
target
))
{
}
else
if
(
TargetIsCDNA
(
T
.
target
))
{
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
ICHECK
(
c_
.
scope
()
==
"local.fragment"
)
<<
"CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<<
"CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<<
C
.
scope
();
<<
c_
.
scope
();
if
(
TargetIsDCU
(
T
.
target
))
{
if
(
TargetIsDCU
(
T
.
target
))
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentCDCU
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
makeGemmFragmentCDCU
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentCCDNA
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
makeGemmFragmentCCDNA
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
}
}
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
a_
->
shape
.
size
();
int
dim_A
=
A
->
shape
.
size
();
auto
shared_layout
=
makeGemmABLayoutCDNA
(
auto
shared_layout
=
makeGemmABLayoutCDNA
(
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]),
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]),
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]),
A
->
dtype
.
bits
(),
kPack
);
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]),
a_
->
dtype
.
bits
(),
kPack_
);
results
.
Set
(
A
,
shared_layout
);
results
.
Set
(
a_
,
shared_layout
);
}
else
if
(
A
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
a_
.
scope
()
==
"local.fragment"
)
{
auto
fragment
=
makeGemmFragmentACDNA
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
auto
fragment
=
A
->
dtype
.
bits
(),
kPack
,
trans_A
);
makeGemmFragmentACDNA
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
results
.
Set
(
A
,
fragment
->
BindThreadRange
(
thread_range
));
a_
->
dtype
.
bits
(),
kPack_
,
transA_
);
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
if
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
)
{
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
auto
shared_layout
=
makeGemmABLayoutCDNA
(
auto
shared_layout
=
makeGemmABLayoutCDNA
(
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]),
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]),
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]),
B
->
dtype
.
bits
(),
kPack
);
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]),
b_
->
dtype
.
bits
(),
kPack
_
);
results
.
Set
(
B
,
shared_layout
);
results
.
Set
(
b_
,
shared_layout
);
}
else
if
(
B
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
b_
.
scope
()
==
"local.fragment"
)
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentB
(
M
,
N
,
K
,
M
/
warp_m
,
N
/
warp_n
,
trans
_
B
);
makeGemmFragmentB
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
transB
_
);
results
.
Set
(
B
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
b_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
...
@@ -880,18 +916,17 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
...
@@ -880,18 +916,17 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
TVM_REGISTER_OP
(
"tl.GemmWarpPolicy"
)
TVM_REGISTER_OP
(
"tl.GemmWarpPolicy"
)
.
set_attr
<
TScriptPrinterName
>
(
"TScriptPrinterName"
,
"GemmWarpPolicy"
);
.
set_attr
<
TScriptPrinterName
>
(
"TScriptPrinterName"
,
"GemmWarpPolicy"
);
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
GemmNode
::
RegisterReflection
();
GemmNode
::
RegisterReflection
();
GemmWarpPolicyNode
::
RegisterReflection
();
GemmWarpPolicyNode
::
RegisterReflection
();
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.GemmWarpPolicyComputeWarpPartition"
,
refl
::
GlobalDef
().
def
(
"tl.GemmWarpPolicyComputeWarpPartition"
,
[](
GemmWarpPolicy
policy
,
int
M
,
int
N
,
int
block_size
,
[](
GemmWarpPolicy
policy
,
int
M
,
int
N
,
int
block_size
,
Target
target
,
GemmInst
gemm_inst
)
{
Target
target
,
GemmInst
gemm_inst
)
{
policy
->
C
omputeWarpPartition
(
M
,
N
,
block_size
,
target
,
policy
->
c
omputeWarpPartition
(
M
,
N
,
block_size
,
target
,
gemm_inst
);
gemm_inst
);
return
;
});
});
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/gemm.h
View file @
bbbf4207
...
@@ -30,8 +30,7 @@ public:
...
@@ -30,8 +30,7 @@ public:
mutable
int
n_warp
{
0
};
mutable
int
n_warp
{
0
};
int
policy_type
;
int
policy_type
;
static
constexpr
const
char
*
_type_key
=
"tl.GemmWarpPolicy"
;
TVM_FFI_DECLARE_OBJECT_INFO
(
"tl.GemmWarpPolicy"
,
GemmWarpPolicyNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
GemmWarpPolicyNode
,
Object
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
...
@@ -41,22 +40,7 @@ public:
...
@@ -41,22 +40,7 @@ public:
.
def_ro
(
"n_warp"
,
&
GemmWarpPolicyNode
::
n_warp
);
.
def_ro
(
"n_warp"
,
&
GemmWarpPolicyNode
::
n_warp
);
}
}
bool
SEqualReduce
(
const
GemmWarpPolicyNode
*
other
,
std
::
pair
<
int
,
int
>
computeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
SEqualReducer
equal
)
const
{
return
equal
(
policy_type
,
other
->
policy_type
)
&&
equal
(
m_warp
,
other
->
m_warp
)
&&
equal
(
n_warp
,
other
->
n_warp
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
policy_type
);
hash_reduce
(
m_warp
);
hash_reduce
(
n_warp
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
std
::
pair
<
int
,
int
>
ComputeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
Target
target
,
Target
target
,
GemmInst
gemm_inst
)
const
;
GemmInst
gemm_inst
)
const
;
...
@@ -74,22 +58,23 @@ public:
...
@@ -74,22 +58,23 @@ public:
class
GemmWarpPolicy
:
public
ObjectRef
{
class
GemmWarpPolicy
:
public
ObjectRef
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
GemmWarpPolicy
,
ObjectRef
,
GemmWarpPolicyNode
);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE
(
GemmWarpPolicy
,
ObjectRef
,
GemmWarpPolicyNode
);
explicit
GemmWarpPolicy
(
GemmWarpPolicyType
policy_type
)
{
explicit
GemmWarpPolicy
(
GemmWarpPolicyType
policy_type
)
{
auto
node
=
make_object
<
GemmWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmWarpPolicyNode
>
();
node
->
policy_type
=
(
int
)
policy_type
;
node
->
policy_type
=
(
int
)
policy_type
;
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
explicit
GemmWarpPolicy
(
int
policy_type
)
{
explicit
GemmWarpPolicy
(
int
policy_type
)
{
auto
node
=
make_object
<
GemmWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmWarpPolicyNode
>
();
node
->
policy_type
=
policy_type
;
node
->
policy_type
=
policy_type
;
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
explicit
GemmWarpPolicy
(
int
m_warp
,
int
n_warp
)
{
explicit
GemmWarpPolicy
(
int
m_warp
,
int
n_warp
)
{
auto
node
=
make_object
<
GemmWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmWarpPolicyNode
>
();
node
->
m_warp
=
m_warp
;
node
->
m_warp
=
m_warp
;
node
->
n_warp
=
n_warp
;
node
->
n_warp
=
n_warp
;
node
->
policy_type
=
(
int
)
GemmWarpPolicyType
::
kFree
;
node
->
policy_type
=
(
int
)
GemmWarpPolicyType
::
kFree
;
...
@@ -99,89 +84,48 @@ public:
...
@@ -99,89 +84,48 @@ public:
class
GemmNode
:
public
TileOperatorNode
{
class
GemmNode
:
public
TileOperatorNode
{
public:
public:
bool
C
heckW
GMMA
()
const
;
bool
c
heckW
gmma
()
const
;
tir
::
Buffer
A
,
B
,
C
;
tir
::
Buffer
a_
,
b_
,
c_
;
//
pointer to the
A, B
,
C
//
BufferRegion for
A, B
and
C
PrimExpr
Aptr
,
Bptr
,
Cptr
;
BufferRegion
aRegion_
,
bRegion_
,
cRegion_
;
bool
trans
_
A
,
trans
_
B
;
bool
transA
_
,
transB
_
;
int
M
,
N
,
K
;
int
m_
,
n_
,
k_
;
int
stride
_
A
,
stride
_
B
;
int
strideA
_
,
strideB
_
;
int
offset
_
A
,
offset
_
B
;
int
offsetA
_
,
offsetB
_
;
PrimExpr
clear
_a
ccum
=
const_false
();
PrimExpr
clear
A
ccum
_
=
const_false
();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
// only will be enabled under cdna mfma instructions
int
kPack
=
1
;
int
kPack_
=
1
;
int
wg_wait
=
0
;
int
wgWait_
=
0
;
PrimExpr
mbarptr
;
PrimExpr
mbarPtr_
;
std
::
optional
<
tir
::
Buffer
>
mbar
;
// mbar is optional, only used for TCGEN5MMA
std
::
optional
<
tir
::
Buffer
>
mbar_
;
// mbar is optional, only used for TCGEN5MMA
Array
<
PrimExpr
>
C_coords
;
Array
<
PrimExpr
>
cCoords_
;
mutable
GemmWarpPolicy
policy
;
mutable
GemmWarpPolicy
policy_
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.Gemm"
,
GemmNode
,
TileOperatorNode
);
static
constexpr
const
char
*
_type_key
=
"tl.Gemm"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
GemmNode
,
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
ObjectDef
<
GemmNode
>
()
refl
::
ObjectDef
<
GemmNode
>
()
.
def_ro
(
"A"
,
&
GemmNode
::
A
)
.
def_ro
(
"a"
,
&
GemmNode
::
a_
)
.
def_ro
(
"B"
,
&
GemmNode
::
B
)
.
def_ro
(
"b"
,
&
GemmNode
::
b_
)
.
def_ro
(
"C"
,
&
GemmNode
::
C
)
.
def_ro
(
"c"
,
&
GemmNode
::
c_
)
.
def_ro
(
"Aptr"
,
&
GemmNode
::
Aptr
)
.
def_ro
(
"aRegion"
,
&
GemmNode
::
aRegion_
)
.
def_ro
(
"Bptr"
,
&
GemmNode
::
Bptr
)
.
def_ro
(
"bRegion"
,
&
GemmNode
::
bRegion_
)
.
def_ro
(
"Cptr"
,
&
GemmNode
::
Cptr
)
.
def_ro
(
"cRegion"
,
&
GemmNode
::
cRegion_
)
.
def_ro
(
"trans_A"
,
&
GemmNode
::
trans_A
)
.
def_ro
(
"transA"
,
&
GemmNode
::
transA_
)
.
def_ro
(
"trans_B"
,
&
GemmNode
::
trans_B
)
.
def_ro
(
"transB"
,
&
GemmNode
::
transB_
)
.
def_ro
(
"M"
,
&
GemmNode
::
M
)
.
def_ro
(
"m"
,
&
GemmNode
::
m_
)
.
def_ro
(
"N"
,
&
GemmNode
::
N
)
.
def_ro
(
"n"
,
&
GemmNode
::
n_
)
.
def_ro
(
"K"
,
&
GemmNode
::
K
)
.
def_ro
(
"k"
,
&
GemmNode
::
k_
)
.
def_ro
(
"stride_A"
,
&
GemmNode
::
stride_A
)
.
def_ro
(
"strideA"
,
&
GemmNode
::
strideA_
)
.
def_ro
(
"stride_B"
,
&
GemmNode
::
stride_B
)
.
def_ro
(
"strideB"
,
&
GemmNode
::
strideB_
)
.
def_ro
(
"offset_A"
,
&
GemmNode
::
offset_A
)
.
def_ro
(
"offsetA"
,
&
GemmNode
::
offsetA_
)
.
def_ro
(
"offset_B"
,
&
GemmNode
::
offset_B
)
.
def_ro
(
"offsetB"
,
&
GemmNode
::
offsetB_
)
.
def_ro
(
"clear_accum"
,
&
GemmNode
::
clear_accum
)
.
def_ro
(
"clearAccum"
,
&
GemmNode
::
clearAccum_
)
.
def_ro
(
"kPack"
,
&
GemmNode
::
kPack
)
.
def_ro
(
"kPack"
,
&
GemmNode
::
kPack_
)
.
def_ro
(
"wg_wait"
,
&
GemmNode
::
wg_wait
)
.
def_ro
(
"wgWait"
,
&
GemmNode
::
wgWait_
)
.
def_ro
(
"policy"
,
&
GemmNode
::
policy
);
.
def_ro
(
"policy"
,
&
GemmNode
::
policy_
);
}
bool
SEqualReduce
(
const
GemmNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
A
,
other
->
A
)
&&
equal
(
B
,
other
->
B
)
&&
equal
(
C
,
other
->
C
)
&&
equal
(
Aptr
,
other
->
Aptr
)
&&
equal
(
Bptr
,
other
->
Bptr
)
&&
equal
(
Cptr
,
other
->
Cptr
)
&&
equal
(
trans_A
,
other
->
trans_A
)
&&
equal
(
trans_B
,
other
->
trans_B
)
&&
equal
(
M
,
other
->
M
)
&&
equal
(
N
,
other
->
N
)
&&
equal
(
K
,
other
->
K
)
&&
equal
(
stride_A
,
other
->
stride_A
)
&&
equal
(
stride_B
,
other
->
stride_B
)
&&
equal
(
offset_A
,
other
->
offset_A
)
&&
equal
(
offset_B
,
other
->
offset_B
)
&&
equal
(
clear_accum
,
other
->
clear_accum
)
&&
equal
(
kPack
,
other
->
kPack
)
&&
equal
(
wg_wait
,
other
->
wg_wait
)
&&
equal
(
policy
,
other
->
policy
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
A
);
hash_reduce
(
B
);
hash_reduce
(
C
);
hash_reduce
(
Aptr
);
hash_reduce
(
Bptr
);
hash_reduce
(
Cptr
);
hash_reduce
(
trans_A
);
hash_reduce
(
trans_B
);
hash_reduce
(
M
);
hash_reduce
(
N
);
hash_reduce
(
K
);
hash_reduce
(
stride_A
);
hash_reduce
(
stride_B
);
hash_reduce
(
offset_A
);
hash_reduce
(
offset_B
);
hash_reduce
(
clear_accum
);
hash_reduce
(
kPack
);
hash_reduce
(
wg_wait
);
hash_reduce
(
policy
);
}
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
...
@@ -190,16 +134,16 @@ public:
...
@@ -190,16 +134,16 @@ public:
TileOperator
Clone
()
const
;
TileOperator
Clone
()
const
;
private:
private:
GemmInst
G
etGemmInst
(
int
block_size
,
Target
target
)
const
;
GemmInst
g
etGemmInst
(
int
block_size
,
Target
target
)
const
;
bool
A
llowT
CGEN5MMA
(
Target
target
)
const
;
bool
a
llowT
cgen5Mma
(
Target
target
)
const
;
bool
A
llowW
GMMA
(
int
block_size
,
Target
target
)
const
;
bool
a
llowW
gmma
(
int
block_size
,
Target
target
)
const
;
mutable
bool
completed_
=
false
;
mutable
bool
completed_
=
false
;
};
};
class
Gemm
:
public
TileOperator
{
class
Gemm
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
Gemm
,
TileOperator
,
GemmNode
);
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
Gemm
,
TileOperator
,
GemmNode
);
TVM_DLL
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
Gemm
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
@@ -207,4 +151,4 @@ public:
...
@@ -207,4 +151,4 @@ public:
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_GEMM_H_
#endif // TVM_TL_OP_GEMM_H_
\ No newline at end of file
src/op/gemm_py.cc
View file @
bbbf4207
...
@@ -12,13 +12,101 @@
...
@@ -12,13 +12,101 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "../target/utils.h"
#include "tvm/ffi/string.h"
#include "region.h"
#include "tcgen5_meta.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
using
namespace
tir
;
using
namespace
tir
;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static
BufferRegion
NormalizeToBufferRegion
(
const
PrimExpr
&
arg
,
const
BufferMap
&
vmap
)
{
// Case 1: Already a BufferRegion
if
(
arg
->
IsInstance
<
BufferRegionNode
>
())
{
return
Downcast
<
BufferRegion
>
(
arg
);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if
(
const
auto
*
load
=
arg
.
as
<
BufferLoadNode
>
())
{
Array
<
Range
>
ranges
;
for
(
const
PrimExpr
&
index
:
load
->
indices
)
{
if
(
const
auto
*
ramp
=
index
.
as
<
RampNode
>
())
{
ICHECK
(
ramp
->
stride
.
as
<
IntImmNode
>
())
<<
"Ramp stride must be IntImm"
;
ICHECK_EQ
(
ramp
->
stride
.
as
<
IntImmNode
>
()
->
value
,
1
)
<<
"Only stride-1 Ramp is supported in GEMM region conversion"
;
ICHECK
(
ramp
->
lanes
.
as
<
IntImmNode
>
())
<<
"Scalable vector lanes not supported in GEMM region conversion"
;
ranges
.
push_back
(
Range
::
FromMinExtent
(
ramp
->
base
,
ramp
->
lanes
));
}
else
{
ranges
.
push_back
(
Range
::
FromMinExtent
(
index
,
1
));
}
}
return
BufferRegion
(
load
->
buffer
,
ranges
);
}
// Case 3: Call nodes
if
(
const
auto
*
call
=
arg
.
as
<
CallNode
>
())
{
// tl.region(...) — reconstruct via RegionOp
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
RegionOp
region
(
call
->
args
,
vmap
);
return
BufferRegion
(
region
->
GetBuffer
(),
region
->
GetRanges
());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
Var
var
=
Downcast
<
Var
>
(
call
->
args
[
1
]);
Buffer
buf
=
vmap
.
at
(
var
);
Array
<
Range
>
ranges
;
for
(
PrimExpr
extent
:
buf
->
shape
)
{
ranges
.
push_back
(
Range
(
IntImm
(
extent
->
dtype
,
0
),
extent
));
}
return
BufferRegion
(
buf
,
ranges
);
}
}
LOG
(
FATAL
)
<<
"Unsupported GEMM argument for BufferRegion: "
<<
arg
;
throw
;
// Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static
PrimExpr
MakeAccessPtrFromRegion
(
const
BufferRegion
&
region
,
int
rw_mask
)
{
Buffer
buf
=
region
->
buffer
;
int
ndim
=
static_cast
<
int
>
(
buf
->
shape
.
size
());
ICHECK
(
ndim
>=
2
)
<<
"GEMM expects buffers with at least 2 dims"
;
// Compute row-major strides
std
::
vector
<
PrimExpr
>
strides
(
ndim
);
PrimExpr
one
=
make_const
(
buf
->
shape
[
0
].
dtype
(),
1
);
PrimExpr
cur
=
one
;
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
cur
;
cur
=
cur
*
buf
->
shape
[
i
];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr
offset
=
make_const
(
buf
->
shape
[
0
].
dtype
(),
0
);
for
(
int
i
=
0
;
i
<
ndim
-
2
;
++
i
)
{
offset
=
offset
+
region
->
region
[
i
]
->
min
*
strides
[
i
];
}
// Extent: last two extents product (elements)
PrimExpr
extent
=
region
->
region
[
ndim
-
2
]
->
extent
*
region
->
region
[
ndim
-
1
]
->
extent
;
// ptype and return handle
PrimExpr
ptype
=
tir
::
TypeAnnotation
(
buf
->
dtype
);
Array
<
PrimExpr
>
acc_args
{
ptype
,
buf
->
data
,
offset
,
extent
,
IntImm
(
DataType
::
Int
(
32
),
rw_mask
)};
return
Call
(
DataType
::
Handle
(),
builtin
::
tvm_access_ptr
(),
acc_args
);
}
/**
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
* map.
...
@@ -48,34 +136,43 @@ using namespace tir;
...
@@ -48,34 +136,43 @@ using namespace tir;
* performed here.
* performed here.
*/
*/
GemmPy
::
GemmPy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
GemmPy
::
GemmPy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
GemmPyNode
>
node
=
make_object
<
GemmPyNode
>
();
ObjectPtr
<
GemmPyNode
>
node
=
tvm
::
ffi
::
make_object
<
GemmPyNode
>
();
node
->
Aptr
=
args
[
0
];
node
->
aRegion_
=
NormalizeToBufferRegion
(
args
[
0
],
vmap
);
node
->
Bptr
=
args
[
1
];
node
->
bRegion_
=
NormalizeToBufferRegion
(
args
[
1
],
vmap
);
node
->
Cptr
=
args
[
2
];
node
->
cRegion_
=
NormalizeToBufferRegion
(
args
[
2
],
vmap
);
node
->
A
=
vmap
[
GetVarFromAccessPtr
(
node
->
Aptr
)];
node
->
B
=
vmap
[
GetVarFromAccessPtr
(
node
->
Bptr
)];
node
->
a_
=
node
->
aRegion_
->
buffer
;
node
->
C
=
vmap
[
GetVarFromAccessPtr
(
node
->
Cptr
)];
node
->
b_
=
node
->
bRegion_
->
buffer
;
node
->
trans_A
=
args
[
3
].
as
<
Bool
>
().
value
();
node
->
c_
=
node
->
cRegion_
->
buffer
;
node
->
trans_B
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
transA_
=
args
[
3
].
as
<
Bool
>
().
value
();
node
->
M
=
args
[
5
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
transB_
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
N
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
m_
=
args
[
5
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
K
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
n_
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
policy
=
GemmWarpPolicy
(
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
k_
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
clear_accum
=
args
[
9
].
as
<
PrimExpr
>
().
value
();
node
->
policy_
=
GemmWarpPolicy
(
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
stride_A
=
args
[
10
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
clearAccum_
=
args
[
9
].
as
<
PrimExpr
>
().
value
();
node
->
stride_B
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
strideA_
=
args
[
10
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offset_A
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
strideB_
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offset_B
=
args
[
13
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offsetA_
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
offsetB_
=
args
[
13
].
as
<
IntImm
>
().
value
()
->
value
;
if
(
args
.
size
()
>
14
)
{
if
(
args
.
size
()
>
14
)
{
node
->
kPack
=
args
[
14
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
kPack
_
=
args
[
14
].
as
<
IntImm
>
().
value
()
->
value
;
if
(
node
->
kPack
!=
1
&&
node
->
kPack
!=
2
)
{
if
(
node
->
kPack
_
!=
1
&&
node
->
kPack
_
!=
2
)
{
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
}
}
}
}
if
(
args
.
size
()
>
15
)
{
if
(
args
.
size
()
>
15
)
{
node
->
wg
_w
ait
=
args
[
15
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
wg
W
ait
_
=
args
[
15
].
as
<
IntImm
>
().
value
()
->
value
;
}
}
node
->
mbarPtr_
=
args
[
16
];
if
(
node
->
mbarPtr_
.
as
<
CallNode
>
())
{
node
->
mbar_
=
vmap
[
GetVarFromAccessPtr
(
node
->
mbarPtr_
)];
}
else
{
node
->
mbar_
=
std
::
nullopt
;
}
node
->
cCoords_
=
Array
<
PrimExpr
>
(
{
args
[
17
].
as
<
PrimExpr
>
().
value
(),
args
[
18
].
as
<
PrimExpr
>
().
value
()});
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
...
@@ -88,20 +185,41 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -88,20 +185,41 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
*/
TileOperator
GemmPyNode
::
Clone
()
const
{
TileOperator
GemmPyNode
::
Clone
()
const
{
auto
op
=
make_object
<
GemmPyNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
GemmPyNode
>
(
*
this
);
return
GemmPy
(
op
);
return
GemmPy
(
op
);
}
}
GemmInst
GemmPyNode
::
GetGemmInst
(
int
block_size
,
Target
target
)
const
{
bool
GemmPyNode
::
allowTcgen5Mma
(
Target
target
)
const
{
return
TargetIsSm100
(
target
)
&&
((
a_
.
scope
()
==
"shared.dyn"
||
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.tmem"
)
&&
(
b_
.
scope
()
==
"shared.dyn"
||
b_
.
scope
()
==
"shared"
)
&&
c_
.
scope
()
==
"shared.tmem"
)
&&
GetTCGEN5MMAMeta
(
m_
,
n_
,
k_
,
a_
->
dtype
,
c_
->
dtype
).
first
;
}
bool
GemmPyNode
::
allowWgmma
(
int
block_size
,
Target
target
)
const
{
tvm
::
transform
::
PassContext
ctxt
=
tvm
::
transform
::
PassContext
::
Current
();
int
warp_size
=
TargetGetWarpSize
(
target
);
int
warp_size
=
TargetGetWarpSize
(
target
);
int
num_warps
=
block_size
/
warp_size
;
int
num_warps
=
block_size
/
warp_size
;
bool
allow_wgmma
=
TargetIsHopper
(
target
)
&&
(
this
->
M
>=
64
)
&&
return
!
ctxt
->
GetConfig
(
kDisableWGMMA
,
Optional
<
Bool
>
()).
value_or
(
false
)
&&
(
num_warps
%
4
==
0
)
&&
CheckWGMMA
();
TargetIsHopper
(
target
)
&&
(
this
->
m_
>=
64
)
&&
(
num_warps
%
4
==
0
)
&&
if
(
allow_wgmma
)
{
checkWgmma
();
}
GemmInst
GemmPyNode
::
getGemmInst
(
int
block_size
,
Target
target
)
const
{
bool
allow_tcgen5mma
=
allowTcgen5Mma
(
target
);
bool
allow_wgmma
=
allowWgmma
(
block_size
,
target
);
if
(
allow_tcgen5mma
)
{
return
GemmInst
::
kTCGEN5MMA
;
}
else
if
(
allow_wgmma
)
{
return
GemmInst
::
kWGMMA
;
return
GemmInst
::
kWGMMA
;
}
else
if
(
TargetIsCDNA
(
target
))
{
}
else
if
(
TargetIsCDNA
(
target
))
{
return
GemmInst
::
kMFMA
;
return
GemmInst
::
kMFMA
;
}
else
if
(
TargetIsCuda
(
target
))
{
}
else
if
(
TargetIsVolta
(
target
)
||
TargetIsAmpere
(
target
)
||
TargetIsTuring
(
target
)
||
TargetIsHopper
(
target
)
||
TargetIsSm100
(
target
))
{
return
GemmInst
::
kMMA
;
return
GemmInst
::
kMMA
;
}
else
{
}
else
{
ICHECK
(
0
)
<<
"Unsupported target for gemm: "
<<
target
->
str
();
ICHECK
(
0
)
<<
"Unsupported target for gemm: "
<<
target
->
str
();
...
@@ -140,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
...
@@ -140,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
* @return true if WGMMA is supported for the current buffers, dtypes, and
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
* transpose/shape constraints; false otherwise.
*/
*/
bool
GemmPyNode
::
C
heckW
GMMA
()
const
{
bool
GemmPyNode
::
c
heckW
gmma
()
const
{
if
(
B
.
scope
()
!=
"shared.dyn"
&&
B
.
scope
()
!=
"shared"
)
{
if
(
b_
.
scope
()
!=
"shared.dyn"
&&
b_
.
scope
()
!=
"shared"
)
{
return
false
;
return
false
;
}
}
if
(
C
->
dtype
==
DataType
::
Float
(
16
))
{
if
(
c_
->
dtype
==
DataType
::
Float
(
16
))
{
if
(
A
->
dtype
==
DataType
::
Float
(
16
)
&&
B
->
dtype
==
DataType
::
Float
(
16
))
if
(
a_
->
dtype
==
DataType
::
Float
(
16
)
&&
b_
->
dtype
==
DataType
::
Float
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e4m3
())
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e4m3
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e5m2
())
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e4m3
())
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e4m3
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e5m2
())
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
if
(
C
->
dtype
==
DataType
::
Float
(
32
))
{
}
else
if
(
c_
->
dtype
==
DataType
::
Float
(
32
))
{
if
(
A
->
dtype
==
DataType
::
Float
(
16
)
&&
B
->
dtype
==
DataType
::
Float
(
16
))
if
(
a_
->
dtype
==
DataType
::
Float
(
16
)
&&
b_
->
dtype
==
DataType
::
Float
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
==
DataType
::
BFloat
(
16
)
&&
else
if
(
a_
->
dtype
==
DataType
::
BFloat
(
16
)
&&
B
->
dtype
==
DataType
::
BFloat
(
16
))
b_
->
dtype
==
DataType
::
BFloat
(
16
))
return
K
%
16
==
0
;
return
k_
%
16
==
0
;
else
if
(
A
->
dtype
==
DataType
::
Float
(
32
)
&&
B
->
dtype
==
DataType
::
Float
(
32
))
else
if
(
a_
->
dtype
==
DataType
::
Float
(
32
)
&&
return
(
!
trans_A
)
&&
trans_B
&&
K
%
8
==
0
;
b_
->
dtype
==
DataType
::
Float
(
32
))
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e4m3
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
8
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e4m3
())
else
if
(
A
->
dtype
.
is_float8_e4m3
()
&&
B
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e4m3
()
&&
b_
->
dtype
.
is_float8_e5m2
())
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e4m3
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e4m3
())
else
if
(
A
->
dtype
.
is_float8_e5m2
()
&&
B
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
return
(
!
trans_A
)
&&
trans_B
&&
K
%
32
==
0
;
else
if
(
a_
->
dtype
.
is_float8_e5m2
()
&&
b_
->
dtype
.
is_float8_e5m2
())
return
(
!
transA_
)
&&
transB_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
if
(
C
->
dtype
==
DataType
::
Int
(
32
))
{
}
else
if
(
c_
->
dtype
==
DataType
::
Int
(
32
))
{
if
(
A
->
dtype
==
DataType
::
Int
(
8
)
&&
B
->
dtype
==
DataType
::
Int
(
8
))
if
(
a_
->
dtype
==
DataType
::
Int
(
8
)
&&
b_
->
dtype
==
DataType
::
Int
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
Int
(
8
)
&&
B
->
dtype
==
DataType
::
UInt
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
Int
(
8
)
&&
b_
->
dtype
==
DataType
::
UInt
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
UInt
(
8
)
&&
B
->
dtype
==
DataType
::
Int
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
UInt
(
8
)
&&
b_
->
dtype
==
DataType
::
Int
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
if
(
A
->
dtype
==
DataType
::
UInt
(
8
)
&&
B
->
dtype
==
DataType
::
UInt
(
8
))
else
if
(
a_
->
dtype
==
DataType
::
UInt
(
8
)
&&
b_
->
dtype
==
DataType
::
UInt
(
8
))
return
(
!
trans
_
A
)
&&
trans
_
B
&&
K
%
32
==
0
;
return
(
!
transA
_
)
&&
transB
_
&&
k_
%
32
==
0
;
else
else
return
false
;
return
false
;
}
else
{
}
else
{
...
@@ -208,8 +327,8 @@ bool GemmPyNode::CheckWGMMA() const {
...
@@ -208,8 +327,8 @@ bool GemmPyNode::CheckWGMMA() const {
*/
*/
static
int
GetArchInt
(
Target
target
)
{
static
int
GetArchInt
(
Target
target
)
{
int
arch_int
=
0
;
int
arch_int
=
0
;
auto
s
=
target
->
GetAttr
<
String
>
(
"arch"
);
auto
s
=
target
->
GetAttr
<
tvm
::
ffi
::
String
>
(
"arch"
);
ICHECK
(
s
.
defined
());
ICHECK
(
s
.
has_value
());
std
::
string
arch
=
s
.
value
();
std
::
string
arch
=
s
.
value
();
if
(
arch
.
rfind
(
"sm_"
,
0
)
==
0
)
{
if
(
arch
.
rfind
(
"sm_"
,
0
)
==
0
)
{
arch_int
=
std
::
stoi
(
arch
.
substr
(
3
));
arch_int
=
std
::
stoi
(
arch
.
substr
(
3
));
...
@@ -221,18 +340,19 @@ static int GetArchInt(Target target) {
...
@@ -221,18 +340,19 @@ static int GetArchInt(Target target) {
Stmt
GemmPyNode
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
Stmt
GemmPyNode
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
GemmInst
gemm_inst
=
G
etGemmInst
(
block_size
,
T
.
target
);
GemmInst
gemm_inst
=
g
etGemmInst
(
block_size
,
T
.
target
);
auto
[
warp_m
,
warp_n
]
=
auto
[
warp_m
,
warp_n
]
=
policy
->
C
omputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
gemm_inst
);
policy
_
->
c
omputeWarpPartition
(
m_
,
n_
,
block_size
,
T
.
target
,
gemm_inst
);
if
(
const
auto
f
=
ffi
::
Function
::
GetGlobal
(
"tl.gemm_py.lower"
))
{
if
(
const
auto
f
=
ffi
::
Function
::
GetGlobal
(
"tl.gemm_py.lower"
))
{
auto
prim_func
=
auto
prim_func
=
Downcast
<
PrimFunc
>
((
*
f
)(
GetRef
<
GemmPy
>
(
this
),
T
.
layout_map
,
T
.
target
,
Downcast
<
PrimFunc
>
((
*
f
)(
tvm
::
ffi
::
GetRef
<
GemmPy
>
(
this
),
T
.
layout_map
,
T
.
thread_bounds
,
T
.
thread_var
));
T
.
target
,
T
.
thread_bounds
,
T
.
thread_var
));
ICHECK
(
prim_func
->
attrs
.
defined
());
ICHECK
(
prim_func
->
attrs
.
defined
());
auto
global_symbol
=
prim_func
->
attrs
.
GetAttr
<
String
>
(
"global_symbol"
);
auto
global_symbol
=
ICHECK
(
global_symbol
.
defined
());
prim_func
->
attrs
.
GetAttr
<
tvm
::
ffi
::
String
>
(
"global_symbol"
);
ICHECK
(
global_symbol
.
has_value
());
if
(
prim_func
->
body
.
as
<
BlockRealizeNode
>
())
{
if
(
prim_func
->
body
.
as
<
BlockRealizeNode
>
())
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
prim_func
->
body
);
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
prim_func
->
body
);
auto
block
=
block_realize
->
block
;
auto
block
=
block_realize
->
block
;
...
@@ -265,7 +385,15 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
...
@@ -265,7 +385,15 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if
(
const
auto
f
=
ffi
::
Function
::
GetGlobal
(
"tl.gemm_py.infer_layout"
))
{
if
(
const
auto
f
=
ffi
::
Function
::
GetGlobal
(
"tl.gemm_py.infer_layout"
))
{
results
=
Downcast
<
LayoutMap
>
(
results
=
Downcast
<
LayoutMap
>
(
(
*
f
)(
GetRef
<
GemmPy
>
(
this
),
T
.
target
,
T
.
thread_bounds
));
(
*
f
)(
tvm
::
ffi
::
GetRef
<
GemmPy
>
(
this
),
T
.
target
,
T
.
thread_bounds
));
// Bind all fragment layouts with the provided thread range
for
(
auto
kv
:
results
)
{
const
Buffer
&
buf
=
kv
.
first
;
const
Layout
&
layout
=
kv
.
second
;
if
(
auto
frag
=
layout
.
as
<
Fragment
>
())
{
results
.
Set
(
buf
,
frag
.
value
()
->
BindThreadRange
(
T
.
thread_bounds
));
}
}
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"No infer layout function found for gemm_py"
;
LOG
(
FATAL
)
<<
"No infer layout function found for gemm_py"
;
}
}
...
@@ -279,15 +407,41 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
...
@@ -279,15 +407,41 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_FFI_STATIC_INIT_BLOCK
({
GemmPyNode
::
RegisterReflection
();
}
);
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
GemmPyNode
::
RegisterReflection
();
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.GemmPyGemmInst"
,
refl
::
GlobalDef
().
def
(
"tl.GemmPyGemmInst"
,
[](
GemmPy
gemm_py
,
int
block_size
,
Target
target
)
{
[](
GemmPy
gemm_py
,
int
block_size
,
Target
target
)
{
return
gemm_py
->
G
etGemmInst
(
block_size
,
target
);
return
gemm_py
->
g
etGemmInst
(
block_size
,
target
);
});
});
});
}
TVM_FFI_STATIC_INIT_BLOCK
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.get_tcgen5_mma_meta"
,
[](
int
M
,
int
N
,
int
K
,
DataType
ab_dtype
,
DataType
c_dtype
)
{
auto
[
success
,
meta
]
=
GetTCGEN5MMAMeta
(
M
,
N
,
K
,
ab_dtype
,
c_dtype
);
Array
<
Integer
>
result
;
if
(
success
)
{
result
.
push_back
(
Integer
(
meta
.
atom_m
));
result
.
push_back
(
Integer
(
meta
.
atom_n
));
result
.
push_back
(
Integer
(
meta
.
atom_k
));
}
return
result
;
});
refl
::
GlobalDef
().
def
(
"tl.get_tcgen5_instr_desc"
,
[](
int
atom_m
,
int
atom_n
,
int
atom_k
,
DataType
ab_dtype
,
DataType
c_dtype
,
bool
a_is_k_major
,
bool
b_is_k_major
,
int
scale_in_a
,
int
scale_in_b
)
{
uint32_t
desc
=
GetTCGEN5InstrDesc
(
atom_m
,
atom_n
,
atom_k
,
ab_dtype
,
c_dtype
,
a_is_k_major
,
b_is_k_major
,
scale_in_a
,
scale_in_b
);
return
Integer
(
static_cast
<
int64_t
>
(
desc
));
});
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/gemm_py.h
View file @
bbbf4207
...
@@ -18,87 +18,54 @@ using namespace tir;
...
@@ -18,87 +18,54 @@ using namespace tir;
class
GemmPyNode
:
public
TileOperatorNode
{
class
GemmPyNode
:
public
TileOperatorNode
{
public:
public:
bool
CheckWGMMA
()
const
;
bool
checkWgmma
()
const
;
tir
::
Buffer
A
,
B
,
C
;
bool
allowTcgen5Mma
(
Target
target
)
const
;
// pointer to the A, B, C
bool
allowWgmma
(
int
block_size
,
Target
target
)
const
;
PrimExpr
Aptr
,
Bptr
,
Cptr
;
tir
::
Buffer
a_
,
b_
,
c_
;
bool
trans_A
,
trans_B
;
// BufferRegion for A, B and C
int
M
,
N
,
K
;
BufferRegion
aRegion_
,
bRegion_
,
cRegion_
;
int
stride_A
,
stride_B
;
bool
transA_
,
transB_
;
int
offset_A
,
offset_B
;
int
m_
,
n_
,
k_
;
PrimExpr
clear_accum
=
const_false
();
int
strideA_
,
strideB_
;
int
offsetA_
,
offsetB_
;
PrimExpr
clearAccum_
=
const_false
();
PrimExpr
mbarPtr_
;
std
::
optional
<
tir
::
Buffer
>
mbar_
;
// mbar is optional, only used for TCGEN5MMA
Array
<
PrimExpr
>
cCoords_
;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
// only will be enabled under cdna mfma instructions
int
kPack
=
1
;
int
kPack
_
=
1
;
int
wg
_w
ait
=
0
;
int
wg
W
ait
_
=
0
;
mutable
GemmWarpPolicy
policy
;
mutable
GemmWarpPolicy
policy
_
;
static
constexpr
const
char
*
_type_key
=
"tl.GemmPy"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.GemmPy"
,
GemmPyNode
,
TileOperatorNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
GemmPyNode
,
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
ObjectDef
<
GemmPyNode
>
()
refl
::
ObjectDef
<
GemmPyNode
>
()
.
def_ro
(
"A"
,
&
GemmPyNode
::
A
)
.
def_ro
(
"a"
,
&
GemmPyNode
::
a_
)
.
def_ro
(
"B"
,
&
GemmPyNode
::
B
)
.
def_ro
(
"b"
,
&
GemmPyNode
::
b_
)
.
def_ro
(
"C"
,
&
GemmPyNode
::
C
)
.
def_ro
(
"c"
,
&
GemmPyNode
::
c_
)
.
def_ro
(
"Aptr"
,
&
GemmPyNode
::
Aptr
)
.
def_ro
(
"aRegion"
,
&
GemmPyNode
::
aRegion_
)
.
def_ro
(
"Bptr"
,
&
GemmPyNode
::
Bptr
)
.
def_ro
(
"bRegion"
,
&
GemmPyNode
::
bRegion_
)
.
def_ro
(
"Cptr"
,
&
GemmPyNode
::
Cptr
)
.
def_ro
(
"cRegion"
,
&
GemmPyNode
::
cRegion_
)
.
def_ro
(
"trans_A"
,
&
GemmPyNode
::
trans_A
)
.
def_ro
(
"transA"
,
&
GemmPyNode
::
transA_
)
.
def_ro
(
"trans_B"
,
&
GemmPyNode
::
trans_B
)
.
def_ro
(
"transB"
,
&
GemmPyNode
::
transB_
)
.
def_ro
(
"M"
,
&
GemmPyNode
::
M
)
.
def_ro
(
"m"
,
&
GemmPyNode
::
m_
)
.
def_ro
(
"N"
,
&
GemmPyNode
::
N
)
.
def_ro
(
"n"
,
&
GemmPyNode
::
n_
)
.
def_ro
(
"K"
,
&
GemmPyNode
::
K
)
.
def_ro
(
"k"
,
&
GemmPyNode
::
k_
)
.
def_ro
(
"stride_A"
,
&
GemmPyNode
::
stride_A
)
.
def_ro
(
"strideA"
,
&
GemmPyNode
::
strideA_
)
.
def_ro
(
"stride_B"
,
&
GemmPyNode
::
stride_B
)
.
def_ro
(
"strideB"
,
&
GemmPyNode
::
strideB_
)
.
def_ro
(
"offset_A"
,
&
GemmPyNode
::
offset_A
)
.
def_ro
(
"offsetA"
,
&
GemmPyNode
::
offsetA_
)
.
def_ro
(
"offset_B"
,
&
GemmPyNode
::
offset_B
)
.
def_ro
(
"offsetB"
,
&
GemmPyNode
::
offsetB_
)
.
def_ro
(
"clear_accum"
,
&
GemmPyNode
::
clear_accum
)
.
def_ro
(
"clearAccum"
,
&
GemmPyNode
::
clearAccum_
)
.
def_ro
(
"kPack"
,
&
GemmPyNode
::
kPack
)
.
def_ro
(
"mbarPtr"
,
&
GemmPyNode
::
mbarPtr_
)
.
def_ro
(
"wg_wait"
,
&
GemmPyNode
::
wg_wait
)
.
def_ro
(
"cCoords"
,
&
GemmPyNode
::
cCoords_
)
.
def_ro
(
"policy"
,
&
GemmPyNode
::
policy
);
.
def_ro
(
"kPack"
,
&
GemmPyNode
::
kPack_
)
.
def_ro
(
"wgWait"
,
&
GemmPyNode
::
wgWait_
)
.
def_ro
(
"policy"
,
&
GemmPyNode
::
policy_
);
}
}
bool
SEqualReduce
(
const
GemmPyNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
A
,
other
->
A
)
&&
equal
(
B
,
other
->
B
)
&&
equal
(
C
,
other
->
C
)
&&
equal
(
Aptr
,
other
->
Aptr
)
&&
equal
(
Bptr
,
other
->
Bptr
)
&&
equal
(
Cptr
,
other
->
Cptr
)
&&
equal
(
trans_A
,
other
->
trans_A
)
&&
equal
(
trans_B
,
other
->
trans_B
)
&&
equal
(
M
,
other
->
M
)
&&
equal
(
N
,
other
->
N
)
&&
equal
(
K
,
other
->
K
)
&&
equal
(
stride_A
,
other
->
stride_A
)
&&
equal
(
stride_B
,
other
->
stride_B
)
&&
equal
(
offset_A
,
other
->
offset_B
)
&&
equal
(
offset_B
,
other
->
offset_B
)
&&
equal
(
clear_accum
,
other
->
clear_accum
)
&&
equal
(
kPack
,
other
->
kPack
)
&&
equal
(
wg_wait
,
other
->
wg_wait
)
&&
equal
(
policy
,
other
->
policy
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
A
);
hash_reduce
(
B
);
hash_reduce
(
C
);
hash_reduce
(
Aptr
);
hash_reduce
(
Bptr
);
hash_reduce
(
Cptr
);
hash_reduce
(
trans_A
);
hash_reduce
(
trans_B
);
hash_reduce
(
M
);
hash_reduce
(
N
);
hash_reduce
(
K
);
hash_reduce
(
stride_A
);
hash_reduce
(
stride_B
);
hash_reduce
(
offset_A
);
hash_reduce
(
offset_B
);
hash_reduce
(
clear_accum
);
hash_reduce
(
kPack
);
hash_reduce
(
wg_wait
);
hash_reduce
(
policy
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
override
;
InferLevel
level
)
const
override
;
...
@@ -106,7 +73,7 @@ public:
...
@@ -106,7 +73,7 @@ public:
TileOperator
Clone
()
const
;
TileOperator
Clone
()
const
;
// Target GEMM instruction
// Target GEMM instruction
GemmInst
G
etGemmInst
(
int
block_size
,
Target
target
)
const
;
GemmInst
g
etGemmInst
(
int
block_size
,
Target
target
)
const
;
private:
private:
mutable
bool
completed_
=
false
;
mutable
bool
completed_
=
false
;
...
@@ -114,7 +81,7 @@ private:
...
@@ -114,7 +81,7 @@ private:
class
GemmPy
:
public
TileOperator
{
class
GemmPy
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
GemmPy
,
TileOperator
,
GemmPyNode
);
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
GemmPy
,
TileOperator
,
GemmPyNode
);
TVM_DLL
GemmPy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
GemmPy
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
@@ -122,4 +89,4 @@ public:
...
@@ -122,4 +89,4 @@ public:
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif // TVM_TL_OP_GEMM_PY_H_
#endif // TVM_TL_OP_GEMM_PY_H_
\ No newline at end of file
src/op/gemm_sp.cc
View file @
bbbf4207
...
@@ -18,14 +18,14 @@
...
@@ -18,14 +18,14 @@
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
std
::
pair
<
int
,
int
>
GemmSPWarpPolicyNode
::
C
omputeWarpPartition
(
int
M
,
int
N
,
std
::
pair
<
int
,
int
>
GemmSPWarpPolicyNode
::
c
omputeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
int
block_size
,
Target
target
,
Target
target
,
bool
use_wgmma
,
bool
use_wgmma
,
int
bits
)
const
{
int
bits
)
const
{
int
num_warps
=
block_size
/
TargetGetWarpSize
(
target
);
int
num_warps
=
block_size
/
TargetGetWarpSize
(
target
);
auto
[
m_warp
,
n_warp
]
=
GemmWarpPolicyNode
::
C
omputeWarpPartition
(
auto
[
m_warp
,
n_warp
]
=
GemmWarpPolicyNode
::
c
omputeWarpPartition
(
M
,
N
,
block_size
,
target
,
use_wgmma
?
GemmInst
::
kWGMMA
:
GemmInst
::
kMMA
);
M
,
N
,
block_size
,
target
,
use_wgmma
?
GemmInst
::
kWGMMA
:
GemmInst
::
kMMA
);
// Special handling for gemm_sp when the tiling size is not a multiple
// Special handling for gemm_sp when the tiling size is not a multiple
...
@@ -84,26 +84,26 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
...
@@ -84,26 +84,26 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
*/
GemmSP
::
GemmSP
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
GemmSP
::
GemmSP
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
GemmSPNode
>
node
=
make_object
<
GemmSPNode
>
();
ObjectPtr
<
GemmSPNode
>
node
=
tvm
::
ffi
::
make_object
<
GemmSPNode
>
();
node
->
A
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
a_
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
E
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
e_
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
B
=
vmap
[
GetVarFromAccessPtr
(
args
[
2
])];
node
->
b_
=
vmap
[
GetVarFromAccessPtr
(
args
[
2
])];
node
->
C
=
vmap
[
GetVarFromAccessPtr
(
args
[
3
])];
node
->
c_
=
vmap
[
GetVarFromAccessPtr
(
args
[
3
])];
node
->
trans
_
A
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
transA
_
=
args
[
4
].
as
<
Bool
>
().
value
();
node
->
trans
_
B
=
args
[
5
].
as
<
Bool
>
().
value
();
node
->
transB
_
=
args
[
5
].
as
<
Bool
>
().
value
();
node
->
M
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
m_
=
args
[
6
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
N
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
n_
=
args
[
7
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
K
=
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
k_
=
args
[
8
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
policy
=
GemmSPWarpPolicy
(
args
[
9
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
policy
_
=
GemmSPWarpPolicy
(
args
[
9
].
as
<
IntImm
>
().
value
()
->
value
);
node
->
clear
_a
ccum
=
args
[
10
].
as
<
Bool
>
().
value
();
node
->
clear
A
ccum
_
=
args
[
10
].
as
<
Bool
>
().
value
();
if
(
args
.
size
()
>
11
)
{
if
(
args
.
size
()
>
11
)
{
node
->
kPack
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
kPack
_
=
args
[
11
].
as
<
IntImm
>
().
value
()
->
value
;
if
(
node
->
kPack
!=
1
&&
node
->
kPack
!=
2
)
{
if
(
node
->
kPack
_
!=
1
&&
node
->
kPack
_
!=
2
)
{
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
ICHECK
(
false
)
<<
"kPack must be 1 or 2"
;
}
}
}
}
if
(
args
.
size
()
>
12
)
{
if
(
args
.
size
()
>
12
)
{
node
->
wg
_w
ait
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
wg
W
ait
_
=
args
[
12
].
as
<
IntImm
>
().
value
()
->
value
;
}
}
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
...
@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator holding a cloned GemmSPNode.
* @return TileOperator A TileOperator holding a cloned GemmSPNode.
*/
*/
TileOperator
GemmSPNode
::
Clone
()
const
{
TileOperator
GemmSPNode
::
Clone
()
const
{
auto
op
=
make_object
<
GemmSPNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
GemmSPNode
>
(
*
this
);
return
GemmSP
(
op
);
return
GemmSP
(
op
);
}
}
...
@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
...
@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int
warp_size
=
32
;
int
warp_size
=
32
;
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
auto
block_size
=
*
as_const_int
(
T
.
thread_bounds
->
extent
);
bool
maybe_wgmma
=
TargetIsHopper
(
T
.
target
)
&&
(
this
->
M
>=
64
)
&&
bool
maybe_wgmma
=
TargetIsHopper
(
T
.
target
)
&&
(
this
->
m_
>=
64
)
&&
(
block_size
/
warp_size
%
4
==
0
);
(
block_size
/
warp_size
%
4
==
0
);
auto
[
warp_m
,
warp_n
]
=
policy
->
C
omputeWarpPartition
(
auto
[
warp_m
,
warp_n
]
=
policy
_
->
c
omputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
maybe_wgmma
,
A
->
dtype
.
bits
());
m_
,
n_
,
block_size
,
T
.
target
,
maybe_wgmma
,
a_
->
dtype
.
bits
());
std
::
stringstream
ss
;
std
::
stringstream
ss
;
std
::
string
op_name
=
"tl::gemm_sp_ss"
;
std
::
string
op_name
=
"tl::gemm_sp_ss"
;
ICHECK
((
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
&&
ICHECK
((
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
&&
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
))
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
))
<<
"Only support shared.dyn scope for A and B, but received "
<<
A
.
scope
()
<<
"Only support shared.dyn scope for A and B, but received "
<<
" and "
<<
B
.
scope
();
<<
a_
.
scope
()
<<
" and "
<<
b_
.
scope
();
ICHECK
((
E
.
scope
()
==
"shared"
||
E
.
scope
()
==
"shared.dyn"
))
ICHECK
((
e_
.
scope
()
==
"shared"
||
e_
.
scope
()
==
"shared.dyn"
))
<<
"Only support shared.dyn scope for E as copy from smem to rmem are "
<<
"Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implementation, found "
"delegated to cute implementation, found "
<<
E
.
scope
();
<<
e_
.
scope
();
ss
<<
op_name
<<
"<"
<<
M
<<
", "
<<
N
<<
", "
<<
K
<<
", "
;
ss
<<
op_name
<<
"<"
<<
m_
<<
", "
<<
n_
<<
", "
<<
k_
<<
", "
;
ss
<<
warp_m
<<
", "
<<
warp_n
<<
", "
;
ss
<<
warp_m
<<
", "
<<
warp_n
<<
", "
;
ss
<<
trans
_
A
<<
", "
<<
trans
_
B
;
ss
<<
transA
_
<<
", "
<<
transB
_
;
ss
<<
", "
<<
clear
_a
ccum
;
ss
<<
", "
<<
clear
A
ccum
_
;
if
(
TargetIsHopper
(
T
.
target
))
{
if
(
TargetIsHopper
(
T
.
target
))
{
ss
<<
", "
<<
(
maybe_wgmma
?
"true"
:
"false"
);
ss
<<
", "
<<
(
maybe_wgmma
?
"true"
:
"false"
);
}
}
if
(
wg
_w
ait
!=
0
)
{
if
(
wg
W
ait
_
!=
0
)
{
ss
<<
", "
<<
wg
_w
ait
;
ss
<<
", "
<<
wg
W
ait
_
;
}
}
ss
<<
">"
;
ss
<<
">"
;
auto
A_buffer
=
T
.
buffer_remap
.
count
(
A
)
?
T
.
buffer_remap
[
A
]
:
A
;
auto
A_buffer
=
T
.
buffer_remap
.
count
(
a_
)
?
T
.
buffer_remap
[
a_
]
:
a_
;
auto
B_buffer
=
T
.
buffer_remap
.
count
(
B
)
?
T
.
buffer_remap
[
B
]
:
B
;
auto
B_buffer
=
T
.
buffer_remap
.
count
(
b_
)
?
T
.
buffer_remap
[
b_
]
:
b_
;
auto
C_buffer
=
T
.
buffer_remap
[
C
];
auto
C_buffer
=
T
.
buffer_remap
[
c_
];
auto
E_buffer
=
T
.
buffer_remap
.
count
(
E
)
?
T
.
buffer_remap
[
E
]
:
E
;
auto
E_buffer
=
T
.
buffer_remap
.
count
(
e_
)
?
T
.
buffer_remap
[
e_
]
:
e_
;
auto
new_call
=
auto
new_call
=
Call
(
DataType
::
Handle
(),
tl
::
tl_gemm_sp
(),
Call
(
DataType
::
Handle
(),
tl
::
tl_gemm_sp
(),
...
@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
...
@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
if
(
completed_
)
if
(
completed_
)
return
{};
return
{};
LayoutMap
results
;
LayoutMap
results
;
ICHECK
(
C
.
scope
()
==
"local.fragment"
);
ICHECK
(
c_
.
scope
()
==
"local.fragment"
);
auto
thread_range
=
T
.
thread_bounds
;
auto
thread_range
=
T
.
thread_bounds
;
auto
block_size
=
*
as_const_int
(
thread_range
->
extent
);
auto
block_size
=
*
as_const_int
(
thread_range
->
extent
);
if
(
TargetIsHopper
(
T
.
target
))
{
if
(
TargetIsHopper
(
T
.
target
))
{
const
int
warp_size
=
32
;
const
int
warp_size
=
32
;
constexpr
int
wgmma_m
=
16
*
4
;
constexpr
int
wgmma_m
=
16
*
4
;
bool
maybe_wgmma
=
bool
maybe_wgmma
=
(
this
->
M
>=
wgmma_m
)
&&
(
block_size
/
warp_size
%
4
==
0
);
(
this
->
m_
>=
wgmma_m
)
&&
(
block_size
/
warp_size
%
4
==
0
);
auto
[
warp_m
,
warp_n
]
=
policy
->
C
omputeWarpPartition
(
auto
[
warp_m
,
warp_n
]
=
policy
_
->
c
omputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
maybe_wgmma
,
A
->
dtype
.
bits
());
m_
,
n_
,
block_size
,
T
.
target
,
maybe_wgmma
,
a_
->
dtype
.
bits
());
auto
fragment
=
auto
fragment
=
maybe_wgmma
maybe_wgmma
?
makeGemmFragmentCHopper
(
m_
,
n_
,
m_
/
warp_m
,
?
makeGemmFragmentCHopper
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
n_
/
warp_n
,
c_
->
dtype
.
bits
())
C
->
dtype
.
bits
())
:
makeGemmFragmentC
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
:
makeGemmFragmentC
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]);
results
.
Set
(
A
,
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
results
.
Set
(
a_
,
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
mat_continuous
,
A
->
dtype
.
bits
(),
mat_continuous
,
a_
->
dtype
.
bits
(),
trans
_
A
?
1
:
2
));
transA
_
?
1
:
2
));
}
else
{
}
else
{
ICHECK
(
false
)
<<
"Not implemented"
;
ICHECK
(
false
)
<<
"Not implemented"
;
}
}
if
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
)
{
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]);
const
int64_t
continuity
=
const
int64_t
continuity
=
trans
_
B
?
mat_continuous
:
mat_continuous
/
warp_n
;
transB
_
?
mat_continuous
:
mat_continuous
/
warp_n
;
results
.
Set
(
B
,
results
.
Set
(
b_
,
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
makeGemmABLayoutHopper
(
mat_stride
,
mat_continuous
,
continuity
,
B
->
dtype
.
bits
(),
trans
_
B
?
2
:
1
));
b_
->
dtype
.
bits
(),
transB
_
?
2
:
1
));
}
else
{
}
else
{
ICHECK
(
false
)
<<
"WGMMA only support B in shared."
;
ICHECK
(
false
)
<<
"WGMMA only support B in shared."
;
}
}
}
else
if
(
TargetIsAmpere
(
T
.
target
))
{
}
else
if
(
TargetIsAmpere
(
T
.
target
))
{
auto
[
warp_m
,
warp_n
]
=
policy
->
C
omputeWarpPartition
(
auto
[
warp_m
,
warp_n
]
=
policy
_
->
c
omputeWarpPartition
(
M
,
N
,
block_size
,
T
.
target
,
false
,
A
->
dtype
.
bits
());
m_
,
n_
,
block_size
,
T
.
target
,
false
,
a_
->
dtype
.
bits
());
auto
fragment
=
auto
fragment
=
makeGemmSparseFragmentC
(
m_
,
n_
,
m_
/
warp_m
,
n_
/
warp_n
,
makeGemmSparseFragmentC
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
c_
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
int
dim_A
=
A
->
shape
.
size
();
int
dim_A
=
a_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
A
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
A
->
shape
[
dim_A
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]);
results
.
Set
(
A
,
makeGemmSparseAmpereABLayout
(
mat_stride
,
mat_continuous
,
results
.
Set
(
a_
,
makeGemmSparseAmpereABLayout
(
mat_stride
,
mat_continuous
,
A
->
dtype
.
bits
()));
a_
->
dtype
.
bits
()));
}
else
if
(
A
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
a_
.
scope
()
==
"local.fragment"
)
{
// auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
// auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
// A->dtype.bits(), trans_A);
// A->dtype.bits(), trans_A);
// results.Set(A, fragment->BindThreadRange(thread_range));
// results.Set(A, fragment->BindThreadRange(thread_range));
...
@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
...
@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
}
else
{
}
else
{
ICHECK
(
0
);
ICHECK
(
0
);
}
}
if
(
B
.
scope
()
==
"shared"
||
B
.
scope
()
==
"shared.dyn"
)
{
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
int
dim_B
=
B
->
shape
.
size
();
int
dim_B
=
b_
->
shape
.
size
();
const
int64_t
mat_stride
=
*
as_const_int
(
B
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_stride
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
B
->
shape
[
dim_B
-
1
]);
const
int64_t
mat_continuous
=
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]);
results
.
Set
(
B
,
makeGemmSparseAmpereABLayout
(
mat_stride
,
mat_continuous
,
results
.
Set
(
b_
,
makeGemmSparseAmpereABLayout
(
mat_stride
,
mat_continuous
,
B
->
dtype
.
bits
()));
b_
->
dtype
.
bits
()));
}
else
if
(
B
.
scope
()
==
"local.fragment"
)
{
}
else
if
(
b_
.
scope
()
==
"local.fragment"
)
{
// auto fragment =
// auto fragment =
// makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
// makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
// results.Set(B, fragment->BindThreadRange(thread_range));
// results.Set(B, fragment->BindThreadRange(thread_range));
...
@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
...
@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_FFI_STATIC_INIT_BLOCK
({
GemmSPNode
::
RegisterReflection
();
}
);
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
GemmSPNode
::
RegisterReflection
();
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/gemm_sp.h
View file @
bbbf4207
...
@@ -18,30 +18,32 @@ using namespace tir;
...
@@ -18,30 +18,32 @@ using namespace tir;
class
GemmSPWarpPolicyNode
:
public
GemmWarpPolicyNode
{
class
GemmSPWarpPolicyNode
:
public
GemmWarpPolicyNode
{
public:
public:
std
::
pair
<
int
,
int
>
C
omputeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
std
::
pair
<
int
,
int
>
c
omputeWarpPartition
(
int
M
,
int
N
,
int
block_size
,
Target
target
,
bool
use_wgmma
,
Target
target
,
bool
use_wgmma
,
int
bits
)
const
;
int
bits
)
const
;
TVM_FFI_DECLARE_OBJECT_INFO
(
"tl.GemmSPWarpPolicy"
,
GemmSPWarpPolicyNode
,
GemmWarpPolicyNode
);
};
};
class
GemmSPWarpPolicy
:
public
ObjectRef
{
class
GemmSPWarpPolicy
:
public
ObjectRef
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
GemmSPWarpPolicy
,
ObjectRef
,
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
GemmSPWarpPolicy
,
ObjectRef
,
GemmSPWarpPolicyNode
);
GemmSPWarpPolicyNode
);
explicit
GemmSPWarpPolicy
(
GemmWarpPolicyType
policy_type
)
{
explicit
GemmSPWarpPolicy
(
GemmWarpPolicyType
policy_type
)
{
auto
node
=
make_object
<
GemmSPWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmSPWarpPolicyNode
>
();
node
->
policy_type
=
(
int
)
policy_type
;
node
->
policy_type
=
(
int
)
policy_type
;
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
explicit
GemmSPWarpPolicy
(
int
policy_type
)
{
explicit
GemmSPWarpPolicy
(
int
policy_type
)
{
auto
node
=
make_object
<
GemmSPWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmSPWarpPolicyNode
>
();
node
->
policy_type
=
policy_type
;
node
->
policy_type
=
policy_type
;
data_
=
std
::
move
(
node
);
data_
=
std
::
move
(
node
);
}
}
explicit
GemmSPWarpPolicy
(
int
m_warp
,
int
n_warp
)
{
explicit
GemmSPWarpPolicy
(
int
m_warp
,
int
n_warp
)
{
auto
node
=
make_object
<
GemmSPWarpPolicyNode
>
();
auto
node
=
tvm
::
ffi
::
make_object
<
GemmSPWarpPolicyNode
>
();
node
->
m_warp
=
m_warp
;
node
->
m_warp
=
m_warp
;
node
->
n_warp
=
n_warp
;
node
->
n_warp
=
n_warp
;
node
->
policy_type
=
(
int
)
GemmWarpPolicyType
::
kFree
;
node
->
policy_type
=
(
int
)
GemmWarpPolicyType
::
kFree
;
...
@@ -51,19 +53,18 @@ public:
...
@@ -51,19 +53,18 @@ public:
class
GemmSPNode
:
public
TileOperatorNode
{
class
GemmSPNode
:
public
TileOperatorNode
{
public:
public:
tir
::
Buffer
A
,
B
,
C
,
E
;
tir
::
Buffer
a_
,
b_
,
c_
,
e_
;
bool
trans
_
A
,
trans
_
B
;
bool
transA
_
,
transB
_
;
int
M
,
N
,
K
;
int
m_
,
n_
,
k_
;
bool
clear
_a
ccum
=
false
;
bool
clear
A
ccum
_
=
false
;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
// only will be enabled under cdna mfma instructions
int
kPack
=
1
;
int
kPack
_
=
1
;
int
wg
_w
ait
=
0
;
int
wg
W
ait
_
=
0
;
mutable
GemmSPWarpPolicy
policy
;
mutable
GemmSPWarpPolicy
policy
_
;
static
constexpr
const
char
*
_type_key
=
"tl.GemmSP"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.GemmSP"
,
GemmSPNode
,
TileOperatorNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
GemmSPNode
,
TileOperatorNode
);
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
Stmt
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
override
;
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
LayoutMap
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
override
;
InferLevel
level
)
const
override
;
...
@@ -73,44 +74,19 @@ public:
...
@@ -73,44 +74,19 @@ public:
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
ObjectDef
<
GemmSPNode
>
()
refl
::
ObjectDef
<
GemmSPNode
>
()
.
def_ro
(
"policy"
,
&
GemmSPNode
::
policy
)
.
def_ro
(
"policy"
,
&
GemmSPNode
::
policy_
)
.
def_ro
(
"A"
,
&
GemmSPNode
::
A
)
.
def_ro
(
"a"
,
&
GemmSPNode
::
a_
)
.
def_ro
(
"B"
,
&
GemmSPNode
::
B
)
.
def_ro
(
"b"
,
&
GemmSPNode
::
b_
)
.
def_ro
(
"C"
,
&
GemmSPNode
::
C
)
.
def_ro
(
"c"
,
&
GemmSPNode
::
c_
)
.
def_ro
(
"E"
,
&
GemmSPNode
::
E
)
.
def_ro
(
"e"
,
&
GemmSPNode
::
e_
)
.
def_ro
(
"trans_A"
,
&
GemmSPNode
::
trans_A
)
.
def_ro
(
"transA"
,
&
GemmSPNode
::
transA_
)
.
def_ro
(
"trans_B"
,
&
GemmSPNode
::
trans_B
)
.
def_ro
(
"transB"
,
&
GemmSPNode
::
transB_
)
.
def_ro
(
"M"
,
&
GemmSPNode
::
M
)
.
def_ro
(
"m"
,
&
GemmSPNode
::
m_
)
.
def_ro
(
"N"
,
&
GemmSPNode
::
N
)
.
def_ro
(
"n"
,
&
GemmSPNode
::
n_
)
.
def_ro
(
"K"
,
&
GemmSPNode
::
K
)
.
def_ro
(
"k"
,
&
GemmSPNode
::
k_
)
.
def_ro
(
"clear_accum"
,
&
GemmSPNode
::
clear_accum
)
.
def_ro
(
"clearAccum"
,
&
GemmSPNode
::
clearAccum_
)
.
def_ro
(
"kPack"
,
&
GemmSPNode
::
kPack
)
.
def_ro
(
"kPack"
,
&
GemmSPNode
::
kPack_
)
.
def_ro
(
"wg_wait"
,
&
GemmSPNode
::
wg_wait
);
.
def_ro
(
"wgWait"
,
&
GemmSPNode
::
wgWait_
);
}
bool
SEqualReduce
(
const
GemmSPNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
A
,
other
->
A
)
&&
equal
(
B
,
other
->
B
)
&&
equal
(
C
,
other
->
C
)
&&
equal
(
E
,
other
->
E
)
&&
equal
(
trans_A
,
other
->
trans_A
)
&&
equal
(
trans_B
,
other
->
trans_B
)
&&
equal
(
M
,
other
->
M
)
&&
equal
(
N
,
other
->
N
)
&&
equal
(
K
,
other
->
K
)
&&
equal
(
clear_accum
,
other
->
clear_accum
)
&&
equal
(
kPack
,
other
->
kPack
)
&&
equal
(
wg_wait
,
other
->
wg_wait
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
policy
);
hash_reduce
(
A
);
hash_reduce
(
B
);
hash_reduce
(
C
);
hash_reduce
(
E
);
hash_reduce
(
trans_A
);
hash_reduce
(
trans_B
);
hash_reduce
(
M
);
hash_reduce
(
N
);
hash_reduce
(
K
);
hash_reduce
(
clear_accum
);
hash_reduce
(
kPack
);
hash_reduce
(
wg_wait
);
}
}
private:
private:
...
@@ -119,7 +95,7 @@ private:
...
@@ -119,7 +95,7 @@ private:
class
GemmSP
:
public
TileOperator
{
class
GemmSP
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
GemmSP
,
TileOperator
,
GemmSPNode
);
TVM_
FFI_
DEFINE_OBJECT_REF_METHODS
_NULLABLE
(
GemmSP
,
TileOperator
,
GemmSPNode
);
TVM_DLL
GemmSP
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
TVM_DLL
GemmSP
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
);
static
const
Op
&
Get
();
static
const
Op
&
Get
();
};
};
...
...
src/op/logical.cc
View file @
bbbf4207
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
using
namespace
tir
;
using
namespace
tir
;
...
@@ -50,4 +52,4 @@ TVM_REGISTER_OP("tl.all_of")
...
@@ -50,4 +52,4 @@ TVM_REGISTER_OP("tl.all_of")
.
set_attr
<
FLowerIntrinsic
>
(
"cuda.FLowerIntrinsic"
,
all_of_op
);
.
set_attr
<
FLowerIntrinsic
>
(
"cuda.FLowerIntrinsic"
,
all_of_op
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
\ No newline at end of file
src/op/math.cc
View file @
bbbf4207
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
using
namespace
tir
;
using
namespace
tir
;
...
@@ -33,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int")
...
@@ -33,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int")
.
set_attr
<
TScriptPrinterName
>
(
"TScriptPrinterName"
,
"pow_of_int"
)
.
set_attr
<
TScriptPrinterName
>
(
"TScriptPrinterName"
,
"pow_of_int"
)
.
set_attr
<
FLowerIntrinsic
>
(
"cuda.FLowerIntrinsic"
,
pow_of_int_op
);
.
set_attr
<
FLowerIntrinsic
>
(
"cuda.FLowerIntrinsic"
,
pow_of_int_op
);
PrimExpr
infinity_op
(
PrimExpr
args
)
{
const
CallNode
*
call
=
args
.
as
<
CallNode
>
();
CHECK
(
call
!=
nullptr
);
const
DataType
&
dtype
=
call
->
dtype
;
ICHECK_EQ
(
dtype
.
lanes
(),
1
);
// NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype
if
(
dtype
.
is_float
())
{
if
(
dtype
.
bits
()
==
64
||
dtype
.
bits
()
==
32
||
dtype
.
bits
()
==
16
)
{
return
FloatImm
(
dtype
,
std
::
numeric_limits
<
float
>::
infinity
(),
call
->
span
);
}
}
else
if
(
dtype
.
is_bfloat16
())
{
return
FloatImm
(
dtype
,
std
::
numeric_limits
<
float
>::
infinity
(),
call
->
span
);
}
LOG
(
FATAL
)
<<
"Cannot decide infinity for type "
<<
dtype
;
throw
;
// Unreachable, keeps compiler happy
}
TVM_REGISTER_OP
(
"tl.infinity"
)
.
set_num_inputs
(
1
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kPure
))
.
set_attr
<
TScriptPrinterName
>
(
"TScriptPrinterName"
,
"infinity"
)
.
set_attr
<
FLowerIntrinsic
>
(
"cuda.FLowerIntrinsic"
,
infinity_op
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/operator.cc
View file @
bbbf4207
...
@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
...
@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator
ParseOperator
(
Stmt
stmt
,
BufferMap
vmap
)
{
TileOperator
ParseOperator
(
Stmt
stmt
,
BufferMap
vmap
)
{
if
(
stmt
.
as
<
Evaluate
>
()
&&
stmt
.
as
<
EvaluateNode
>
()
->
value
.
as
<
CallNode
>
())
{
if
(
stmt
.
as
<
Evaluate
>
()
&&
stmt
.
as
<
EvaluateNode
>
()
->
value
.
as
<
CallNode
>
())
{
auto
call
=
stmt
.
as
<
EvaluateNode
>
()
->
value
.
as
<
CallNode
>
();
auto
call
=
stmt
.
as
<
EvaluateNode
>
()
->
value
.
as
<
CallNode
>
();
return
ParseOperator
(
GetRef
<
Call
>
(
call
),
vmap
);
return
ParseOperator
(
tvm
::
ffi
::
GetRef
<
Call
>
(
call
),
vmap
);
}
}
return
TileOperator
();
return
TileOperator
();
}
}
...
@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
...
@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
ICHECK
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
ICHECK
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
auto
var
=
call
->
args
[
1
].
as
<
VarNode
>
();
auto
var
=
call
->
args
[
1
].
as
<
VarNode
>
();
ICHECK
(
var
);
ICHECK
(
var
);
return
GetRef
<
Var
>
(
var
);
return
tvm
::
ffi
::
GetRef
<
Var
>
(
var
);
}
}
}
// namespace tl
}
// namespace tl
...
...
src/op/operator.h
View file @
bbbf4207
...
@@ -39,7 +39,6 @@ struct LowerArgs {
...
@@ -39,7 +39,6 @@ struct LowerArgs {
AddWorkspaceCallback
AddWorkspace
;
AddWorkspaceCallback
AddWorkspace
;
LayoutMap
layout_map
;
LayoutMap
layout_map
;
Map
<
Buffer
,
Buffer
>
buffer_remap
;
Map
<
Buffer
,
Buffer
>
buffer_remap
;
Array
<
Var
>
buffer_var_gemm
;
};
};
struct
LayoutInferArgs
{
struct
LayoutInferArgs
{
...
@@ -62,14 +61,13 @@ public:
...
@@ -62,14 +61,13 @@ public:
virtual
TileOperator
Clone
()
const
=
0
;
virtual
TileOperator
Clone
()
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"tl.TileOperator"
;
TVM_FFI_DECLARE_OBJECT_INFO
(
"tl.TileOperator"
,
TileOperatorNode
,
Object
);
TVM_DECLARE_BASE_OBJECT_INFO
(
TileOperatorNode
,
Object
);
};
};
class
TileOperator
:
public
ObjectRef
{
class
TileOperator
:
public
ObjectRef
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
TileOperator
,
ObjectRef
,
TileOperatorNode
);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE
(
TileOperator
,
ObjectRef
,
TileOperatorNode
);
};
};
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
);
Var
GetVarFromAccessPtr
(
const
PrimExpr
&
expr
);
...
...
src/op/parallel.cc
View file @
bbbf4207
...
@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
...
@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
}
}
TileOperator
ParallelOpNode
::
Clone
()
const
{
TileOperator
ParallelOpNode
::
Clone
()
const
{
auto
op
=
make_object
<
ParallelOpNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
ParallelOpNode
>
(
*
this
);
return
ParallelOp
(
op
);
return
ParallelOp
(
op
);
}
}
...
@@ -620,11 +620,37 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
...
@@ -620,11 +620,37 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
if
(
IsCommonAccessIndice
(
buffer
))
{
if
(
IsCommonAccessIndice
(
buffer
))
{
return
loop_layout_
;
return
loop_layout_
;
}
}
// Prefer a simple path: if original 2D indices form a bijective map, invert
// them directly and avoid introducing a synthetic replicate dimension.
{
auto
res2d
=
arith
::
DetectIterMap
(
indice_map_
[
buffer
],
ToVMap
(
loop_vars_
),
1
,
arith
::
IterMapLevel
::
Bijective
,
const_cast
<
arith
::
Analyzer
*>
(
&
analyzer_
));
if
(
res2d
->
errors
.
empty
())
{
Layout
ind_inv2d
=
Layout
(
loop_vars_
,
indice_map_
[
buffer
])
->
Inverse
();
PrimExpr
indice_rep_extent
=
1
;
PrimExpr
loop_rep_extent
=
loop_layout_
->
ReplicateExtent
();
PrimExpr
dest_buffer_rep_extent
=
indice_rep_extent
*
loop_rep_extent
;
Array
<
PrimExpr
>
fwd2
;
for
(
size_t
i
=
0
;
i
<
buffer
->
shape
.
size
();
i
++
)
{
fwd2
.
push_back
(
InputPlaceholder
(
i
));
}
PrimExpr
thd_b2
=
loop_layout_
->
ForwardThread
(
ind_inv2d
->
Forward
(
fwd2
),
std
::
nullopt
);
return
Fragment
(
buffer
->
shape
,
{},
thd_b2
,
dest_buffer_rep_extent
,
std
::
nullopt
)
->
CondenseReplicateVar
();
}
}
// Otherwise, infer an extra flattened iterator that captures truly-unused
// pieces of the loop space (if any), then try inversion with it.
PrimExpr
rep_b
=
MakeFlattenedExpression
(
PrimExpr
rep_b
=
MakeFlattenedExpression
(
DivideUnusedIterators
(
indice_map_
[
buffer
],
loop_vars_
,
&
analyzer_
));
DivideUnusedIterators
(
indice_map_
[
buffer
],
loop_vars_
,
&
analyzer_
));
auto
bijective_indice
=
indice_map_
[
buffer
];
auto
bijective_indice
=
indice_map_
[
buffer
];
bijective_indice
.
push_back
(
rep_b
);
bijective_indice
.
push_back
(
rep_b
);
Layout
ind_inv
=
Layout
(
loop_vars_
,
bijective_indice
)
->
Inverse
();
Layout
ind_inv
=
Layout
(
loop_vars_
,
bijective_indice
)
->
Inverse
();
PrimExpr
indice_rep_extent
=
PrimExpr
indice_rep_extent
=
ind_inv
->
InputShape
().
back
();
// this is the size of rep_b
ind_inv
->
InputShape
().
back
();
// this is the size of rep_b
PrimExpr
loop_rep_extent
=
loop_layout_
->
ReplicateExtent
();
PrimExpr
loop_rep_extent
=
loop_layout_
->
ReplicateExtent
();
...
@@ -642,7 +668,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
...
@@ -642,7 +668,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
->
CondenseReplicateVar
();
->
CondenseReplicateVar
();
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
ParallelOpNode
::
RegisterReflection
();
}
);
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
ParallelOpNode
::
RegisterReflection
();
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/parallel.h
View file @
bbbf4207
...
@@ -66,8 +66,8 @@ public:
...
@@ -66,8 +66,8 @@ public:
mutable
Optional
<
PrimExpr
>
predicate_
;
mutable
Optional
<
PrimExpr
>
predicate_
;
// Type key for TVM object system.
// Type key for TVM object system.
static
constexpr
const
char
*
_type_key
=
"tl.ParallelOp"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.ParallelOp"
,
ParallelOpNode
,
TVM_DECLARE_FINAL_OBJECT_INFO
(
ParallelOpNode
,
TileOperatorNode
);
TileOperatorNode
);
static
void
RegisterReflection
()
{
static
void
RegisterReflection
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
...
@@ -77,20 +77,6 @@ public:
...
@@ -77,20 +77,6 @@ public:
.
def_ro
(
"predicate"
,
&
ParallelOpNode
::
predicate_
);
.
def_ro
(
"predicate"
,
&
ParallelOpNode
::
predicate_
);
}
}
bool
SEqualReduce
(
const
ParallelOpNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
root_
,
other
->
root_
)
&&
equal
(
loop_layout_
,
other
->
loop_layout_
)
&&
equal
(
predicate_
,
other
->
predicate_
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
root_
);
hash_reduce
(
loop_layout_
);
hash_reduce
(
predicate_
);
}
static
constexpr
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
bool
_type_has_method_shash_reduce
=
true
;
// Construct from a root For loop.
// Construct from a root For loop.
ParallelOpNode
(
For
root
);
ParallelOpNode
(
For
root
);
...
@@ -150,10 +136,11 @@ private:
...
@@ -150,10 +136,11 @@ private:
class
ParallelOp
:
public
TileOperator
{
class
ParallelOp
:
public
TileOperator
{
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS
(
ParallelOp
,
TileOperator
,
ParallelOpNode
);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE
(
ParallelOp
,
TileOperator
,
ParallelOpNode
);
ParallelOp
(
const
For
&
root
)
{
ParallelOp
(
const
For
&
root
)
{
auto
op
=
make_object
<
ParallelOpNode
>
(
root
);
auto
op
=
tvm
::
ffi
::
make_object
<
ParallelOpNode
>
(
root
);
data_
=
std
::
move
(
op
);
data_
=
std
::
move
(
op
);
}
}
};
};
...
...
src/op/reduce.cc
View file @
bbbf4207
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "../op/parallel.h"
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/ir_utils.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -21,10 +22,54 @@ namespace tl {
...
@@ -21,10 +22,54 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static
BufferRegion
NormalizeToBufferRegion
(
const
PrimExpr
&
arg
,
const
BufferMap
&
vmap
)
{
// Case 1: Already a BufferRegion
if
(
arg
->
IsInstance
<
BufferRegionNode
>
())
{
return
Downcast
<
BufferRegion
>
(
arg
);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if
(
const
auto
*
load
=
arg
.
as
<
BufferLoadNode
>
())
{
Array
<
Range
>
ranges
;
for
(
const
PrimExpr
&
index
:
load
->
indices
)
{
if
(
const
auto
*
ramp
=
index
.
as
<
RampNode
>
())
{
ICHECK
(
ramp
->
stride
.
as
<
IntImmNode
>
())
<<
"Ramp stride must be IntImm"
;
ICHECK_EQ
(
ramp
->
stride
.
as
<
IntImmNode
>
()
->
value
,
1
)
<<
"Only stride-1 Ramp is supported in region conversion"
;
ICHECK
(
ramp
->
lanes
.
as
<
IntImmNode
>
())
<<
"Scalable vector lanes not supported in region conversion"
;
ranges
.
push_back
(
Range
::
FromMinExtent
(
ramp
->
base
,
ramp
->
lanes
));
}
else
{
ranges
.
push_back
(
Range
::
FromMinExtent
(
index
,
1
));
}
}
return
BufferRegion
(
load
->
buffer
,
ranges
);
}
// Case 3: Call nodes (only tl.region)
if
(
const
auto
*
call
=
arg
.
as
<
CallNode
>
())
{
// tl.region(...) — reconstruct via RegionOp
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
RegionOp
region
(
call
->
args
,
vmap
);
return
BufferRegion
(
region
->
GetBuffer
(),
region
->
GetRanges
());
}
}
LOG
(
FATAL
)
<<
"Unsupported argument for BufferRegion in reduce: "
<<
arg
;
throw
;
// Unreachable
}
ReduceOp
::
ReduceOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ReduceOp
::
ReduceOp
(
Array
<
PrimExpr
>
args
,
BufferMap
vmap
)
{
ObjectPtr
<
ReduceOpNode
>
node
=
make_object
<
ReduceOpNode
>
();
ObjectPtr
<
ReduceOpNode
>
node
=
tvm
::
ffi
::
make_object
<
ReduceOpNode
>
();
node
->
src
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
// Accept BufferRegion/BufferLoad/tl.region for src/dst
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
srcRegion_
=
NormalizeToBufferRegion
(
args
[
0
],
vmap
);
node
->
dstRegion_
=
NormalizeToBufferRegion
(
args
[
1
],
vmap
);
node
->
src
=
node
->
srcRegion_
->
buffer
;
node
->
dst
=
node
->
dstRegion_
->
buffer
;
std
::
string
reduce_type
=
args
[
2
].
as
<
StringImm
>
().
value
()
->
value
;
std
::
string
reduce_type
=
args
[
2
].
as
<
StringImm
>
().
value
()
->
value
;
node
->
dim
=
args
[
3
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
dim
=
args
[
3
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
type
=
ReduceType
(
reduce_type
);
node
->
type
=
ReduceType
(
reduce_type
);
...
@@ -33,12 +78,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -33,12 +78,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
}
}
TileOperator
ReduceOpNode
::
Clone
()
const
{
TileOperator
ReduceOpNode
::
Clone
()
const
{
auto
op
=
make_object
<
ReduceOpNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
ReduceOpNode
>
(
*
this
);
return
ReduceOp
(
op
);
return
ReduceOp
(
op
);
}
}
TileOperator
CumSumOpNode
::
Clone
()
const
{
TileOperator
CumSumOpNode
::
Clone
()
const
{
auto
op
=
make_object
<
CumSumOpNode
>
(
*
this
);
auto
op
=
tvm
::
ffi
::
make_object
<
CumSumOpNode
>
(
*
this
);
return
CumSumOp
(
op
);
return
CumSumOp
(
op
);
}
}
...
@@ -85,6 +130,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
...
@@ -85,6 +130,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
return
make_zero
(
dst
->
dtype
);
return
make_zero
(
dst
->
dtype
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unsupported reduce type: "
<<
type
->
type
;
LOG
(
FATAL
)
<<
"Unsupported reduce type: "
<<
type
->
type
;
return
PrimExpr
();
}
}
}
}
...
@@ -103,7 +149,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
...
@@ -103,7 +149,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
}
else
if
(
type
->
isMin
())
{
}
else
if
(
type
->
isMin
())
{
return
Min
(
lhs
,
rhs
);
return
Min
(
lhs
,
rhs
);
}
else
if
(
type
->
isAbsMax
())
{
}
else
if
(
type
->
isAbsMax
())
{
return
Max
(
Max
(
lhs
,
rhs
),
-
Min
(
lhs
,
rhs
));
return
Max
(
tvm
::
abs
(
lhs
)
,
tvm
::
abs
(
rhs
));
}
else
if
(
type
->
isBitAnd
())
{
}
else
if
(
type
->
isBitAnd
())
{
return
lhs
&
rhs
;
return
lhs
&
rhs
;
}
else
if
(
type
->
isBitOr
())
{
}
else
if
(
type
->
isBitOr
())
{
...
@@ -359,70 +405,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
...
@@ -359,70 +405,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return
body
;
return
body
;
}
}
auto
is_shared_scope
=
[](
const
std
::
string
&
scope
)
{
return
scope
==
"shared"
||
scope
==
"shared.dyn"
;
};
if
(
is_shared_scope
(
src_scope
)
&&
is_shared_scope
(
dst_scope
))
{
Buffer
src_buffer
=
get_buffer
(
this
->
src
);
Buffer
dst_buffer
=
get_buffer
(
this
->
dst
);
size_t
src_dim
=
src_buffer
->
shape
.
size
();
size_t
dst_dim
=
dst_buffer
->
shape
.
size
();
bool
is_1d_reduce
=
(
src_dim
==
dst_dim
&&
dst_dim
==
1
);
if
(
!
is_1d_reduce
)
{
ICHECK_EQ
(
src_dim
,
dst_dim
+
1
)
<<
"Reduce dimension mismatch."
;
}
else
{
ICHECK_EQ
(
dst_dim
,
1U
)
<<
"Expect scalar layout for 1D reduce."
;
}
auto
thread_extent
=
as_const_int
(
T
.
thread_bounds
->
extent
);
ICHECK
(
thread_extent
)
<<
"Shared-memory reduce requires static thread extent."
;
int
threads
=
*
thread_extent
;
if
(
TargetIsCuda
(
T
.
target
))
{
ICHECK_EQ
(
threads
%
32
,
0
)
<<
"Shared reduce expects blockDim.x to be a multiple of 32 on CUDA."
;
}
else
if
(
TargetIsRocm
(
T
.
target
))
{
ICHECK_EQ
(
threads
%
64
,
0
)
<<
"Shared reduce expects blockDim.x to be a multiple of 64 on HIP."
;
}
bool
use_abs
=
this
->
type
->
isAbsSum
()
||
this
->
type
->
isAbsMax
();
bool
need_accumulate
=
(
!
this
->
clear
)
&&
(
this
->
type
->
isSum
()
||
this
->
type
->
isAbsSum
()
||
this
->
type
->
isBitAnd
()
||
this
->
type
->
isBitOr
()
||
this
->
type
->
isBitXor
());
PrimExpr
reduce_extent
=
src_buffer
->
shape
[
this
->
dim
];
PrimExpr
tail_extent
=
make_const
(
DataType
::
Int
(
32
),
1
);
for
(
size_t
i
=
this
->
dim
+
1
;
i
<
src_dim
;
++
i
)
{
tail_extent
=
analyzer
->
Simplify
(
tail_extent
*
src_buffer
->
shape
[
i
]);
}
PrimExpr
total_dest
=
make_const
(
DataType
::
Int
(
32
),
1
);
for
(
size_t
i
=
0
;
i
<
dst_dim
;
++
i
)
{
total_dest
=
analyzer
->
Simplify
(
total_dest
*
dst_buffer
->
shape
[
i
]);
}
std
::
stringstream
ss
;
std
::
string
reducer
=
this
->
MakeCodegenReducer
();
ss
<<
"tl::SharedReduceWarp<"
<<
reducer
<<
", "
<<
threads
<<
", "
<<
(
use_abs
?
"true"
:
"false"
)
<<
", "
<<
(
need_accumulate
?
"true"
:
"false"
)
<<
">::run"
;
Array
<
PrimExpr
>
call_args
=
{
StringImm
(
ss
.
str
()),
src_buffer
.
access_ptr
(
1
),
dst_buffer
.
access_ptr
(
3
),
cast
(
DataType
::
Int
(
32
),
total_dest
),
cast
(
DataType
::
Int
(
32
),
reduce_extent
),
cast
(
DataType
::
Int
(
32
),
tail_extent
),
this
->
MakeInitValue
()};
return
Evaluate
(
Call
(
dst_buffer
->
dtype
,
builtin
::
call_extern
(),
call_args
));
}
LOG
(
FATAL
)
<<
"Reduce for buffers in scope ("
<<
src_scope
<<
", "
LOG
(
FATAL
)
<<
"Reduce for buffers in scope ("
<<
src_scope
<<
", "
<<
dst_scope
<<
") is not implemented."
;
<<
dst_scope
<<
") is not implemented."
;
return
Stmt
();
return
Stmt
();
...
@@ -432,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
...
@@ -432,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel
level
)
const
{
InferLevel
level
)
const
{
if
(
level
>=
InferLevel
::
kStrict
)
if
(
level
>=
InferLevel
::
kStrict
)
return
{};
return
{};
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
&&
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
&&
T
.
layout_map
.
count
(
src
))
{
T
.
layout_map
.
count
(
src
))
{
auto
src_layout
=
T
.
layout_map
[
src
].
as
<
Fragment
>
().
value
();
auto
src_layout
=
T
.
layout_map
[
src
].
as
<
Fragment
>
().
value
();
...
@@ -452,10 +435,40 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
...
@@ -452,10 +435,40 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
}
}
auto
thd
=
src_layout
->
ForwardThread
(
auto
thd
=
src_layout
->
ForwardThread
(
fwd
,
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
fwd
,
FloorDiv
(
ReplicationPlaceholder
(),
indice_rep_extent
));
// Ensure the thread count is divisible by the replicate extent.
// Otherwise, we cannot infer a valid fragment<->fragment layout.
{
arith
::
Analyzer
analyzer
;
PrimExpr
num_threads
=
T
.
thread_bounds
->
extent
;
// Though the dest_buffer_rep_extent will be compressed at
// CondenseReplicateVar, we need to check the divisibility here to avoid
// the issue that the thread count is not divisible by the replicate
// extent.
if
(
!
analyzer
.
CanProve
(
FloorMod
(
num_threads
,
dest_buffer_rep_extent
)
==
0
)
&&
!
analyzer
.
CanProve
(
FloorMod
(
dest_buffer_rep_extent
,
num_threads
)
==
0
))
{
ICHECK
(
false
)
<<
"ReduceOp fragment layout inference failed: "
"num_threads % replicate_extent != 0. "
<<
"This mapping requires the block's thread count to be "
"divisible by the "
<<
"replicate extent. "
<<
"Try one of: (1) choose a thread block size divisible "
"by replicate_extent; "
<<
"(2) pick a different reduce dimension or adjust the "
"source fragment layout; "
<<
"Details: num_threads="
<<
num_threads
<<
", replicate_extent="
<<
indice_rep_extent
<<
", src="
<<
src
<<
", dst="
<<
dst
;
}
}
Fragment
dst_layout
=
Fragment
dst_layout
=
Fragment
(
dst
->
shape
,
{},
thd
,
dest_buffer_rep_extent
,
std
::
nullopt
)
Fragment
(
dst
->
shape
,
{},
thd
,
dest_buffer_rep_extent
,
std
::
nullopt
)
->
CondenseReplicateVar
()
->
CondenseReplicateVar
()
->
BindThreadRange
(
T
.
thread_bounds
);
->
BindThreadRange
(
T
.
thread_bounds
);
if
(
!
T
.
layout_map
.
count
(
dst
))
if
(
!
T
.
layout_map
.
count
(
dst
))
return
{{
dst
,
dst_layout
}};
return
{{
dst
,
dst_layout
}};
else
{
else
{
...
@@ -512,7 +525,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
...
@@ -512,7 +525,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - dim: dimension to cumsum
/// - dim: dimension to cumsum
/// - reverse: whether to cumsum in reverse order
/// - reverse: whether to cumsum in reverse order
CHECK_EQ
(
args
.
size
(),
4
);
CHECK_EQ
(
args
.
size
(),
4
);
ObjectPtr
<
CumSumOpNode
>
node
=
make_object
<
CumSumOpNode
>
();
ObjectPtr
<
CumSumOpNode
>
node
=
tvm
::
ffi
::
make_object
<
CumSumOpNode
>
();
node
->
src
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
src
=
vmap
[
GetVarFromAccessPtr
(
args
[
0
])];
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
dst
=
vmap
[
GetVarFromAccessPtr
(
args
[
1
])];
node
->
dim
=
args
[
2
].
as
<
IntImm
>
().
value
()
->
value
;
node
->
dim
=
args
[
2
].
as
<
IntImm
>
().
value
()
->
value
;
...
@@ -567,5 +580,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
...
@@ -567,5 +580,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.
set_num_inputs
(
4
)
.
set_num_inputs
(
4
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
// Register FFI reflection metadata for the reduce-family op nodes defined in
// this file (ReduceOp, CumSumOp, and the ReduceType helper node).
// TVM_FFI_STATIC_INIT_BLOCK runs this body during static initialization
// (per the macro's name — TODO confirm exact timing against the macro's
// definition in the TVM FFI headers), so the reflection tables are populated
// before any of these ops are looked up through the FFI.
TVM_FFI_STATIC_INIT_BLOCK() {
  ReduceOpNode::RegisterReflection();
  CumSumOpNode::RegisterReflection();
  ReduceTypeNode::RegisterReflection();
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
Prev
1
2
3
4
5
6
7
8
9
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment