OpenDAS / tilelang · Commit bbbf4207 (Unverified)
Authored Nov 14, 2025 by guchaoyang; committed by GitHub on Nov 14, 2025.
Parents: 8f4628e0, 5eb30a4f · Changes: 286

    Merge branch 'main' into dcu

Showing 20 changed files with 1075 additions and 906 deletions (+1075 −906)
src/op/builtin.h            +64   −6
src/op/copy.cc              +15   −9
src/op/copy.h               +7    −45
src/op/fill.cc              +44   −12
src/op/fill.h               +3    −17
src/op/finalize_reducer.cc  +3    −3
src/op/finalize_reducer.h   +5    −17
src/op/gemm.cc              +341  −306
src/op/gemm.h               +47   −103
src/op/gemm_py.cc           +234  −80
src/op/gemm_py.h            +42   −75
src/op/gemm_sp.cc           +80   −80
src/op/gemm_sp.h            +30   −54
src/op/logical.cc           +3    −1
src/op/math.cc              +28   −0
src/op/operator.cc          +2    −2
src/op/operator.h           +3    −5
src/op/parallel.cc          +28   −2
src/op/parallel.h           +5    −18
src/op/reduce.cc            +91   −71
src/op/builtin.h

@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
 /*!
  * \brief tvm intrinsics for ptx tensor core wgmma instructions.
  *
- * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- *                   a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv,
- *                   StringImm b_dtype_abbrv, StringImm accum_dtype_abbrv,
- *                   Var A_descriptor, PrimExpr A_offset, Var B_descriptor,
- *                   Var B_offset, Var C_data, Var C_offset, bool scale_out,
- *                   bool scale_in_a, bool scale_in_b);
+ * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
+ *                   bool b_is_k_major, StringImm a_dtype_abbrv,
+ *                   StringImm b_dtype_abbrv, StringImm accum_dtype_abbrv,
+ *                   Var A_descriptor, PrimExpr A_offset, Var B_descriptor,
+ *                   Var B_offset, Var C_data, Var C_offset, bool scale_out,
+ *                   bool scale_in_a, bool scale_in_b);
  */
 TVM_DLL const Op &ptx_wgmma_rs();
+
+/*!
+ * \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
+ */
+TVM_DLL const Op &ptx_tcgen05_mma_ss();
+
+/*!
+ * \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
+ */
+TVM_DLL const Op &ptx_tcgen05_mma_ts();
+
 /*!
  * \brief tvm intrinsics for initializing tensor memory
  *
@@ -265,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
  */
 TVM_DLL const Op &ptx_deallocate_tensor_memory();
+
+/*!
+ * \brief tvm intrinsic for ptx tensor core mma instructions on SM70.
+ *
+ * void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout,
+ *                   StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ *                   Var multiplicand_a, Expr a_index,
+ *                   Var multiplicand_b, Expr b_index,
+ *                   Var accumulator, Expr c_index, bool saturate);
+ */
+TVM_DLL const Op &ptx_mma_sm70();
+
 /*!
  * \brief tvm intrinsics for ldmatrix
  *
@@ -361,6 +382,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
  */
 TVM_DLL const Op &warpgroup_wait();
+
+/*!
+ * \brief Fence accumulator operand registers for upcoming WGMMA operations
+ *
+ * warpgroup_fence_operand(dtype, ptr, offset, num_regs)
+ *
+ */
+TVM_DLL const Op &warpgroup_fence_operand();
+
 /*!
  * \brief Return the canonical lane index for the calling thread.
  *
@@ -494,7 +523,21 @@ TVM_DLL const Op &tl_shuffle_elect();
  * This op is used to represent a descriptor initialization operation in
  * tilelang.
  */
-TVM_DLL const Op &initialize_descriptor();
+TVM_DLL const Op &initialize_wgmma_descriptor();
+
+/*!
+ * \brief tilelang intrinsic for initializing a descriptor buffer for
+ * tcgen05 mma.
+ */
+TVM_DLL const Op &initialize_tcgen05_descriptor();
+
+/*!
+ * \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
+ *
+ * This op wraps the device-side arrive used to signal completion of MMA work
+ * to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
+ */
+TVM_DLL const Op &tcgen05_mma_arrive();
 /*!
  * \brief tilelang intrinsic for setting the start address of a descriptor
@@ -505,6 +548,7 @@ TVM_DLL const Op &initialize_descriptor();
  */
 TVM_DLL const Op &increase_descriptor_offset();
 /*!
  * \brief tilelang intrinsic for element-wise atomic addition.
  *
@@ -513,6 +557,20 @@ TVM_DLL const Op &increase_descriptor_offset();
  */
 TVM_DLL const Op &atomicadd_elem_op();
+
+/*!
+ * \brief tilelang intrinsic for assert on device.
+ *
+ * This op is used to represent an assert on device
+ */
+TVM_DLL const Op &device_assert();
+
+/*!
+ * \brief tilelang intrinsic for assert on device with additional message.
+ *
+ * This op is used to represent an assert on device with additional message.
+ */
+TVM_DLL const Op &device_assert_with_msg();
+
 } // namespace tl
 } // namespace tvm
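These declarations only expose handles: each function returns the canonical Op singleton, which lowering code invokes through an opaque tir::Call. A minimal usage sketch for the newly added device assert follows; the "tl.device_assert" registry name is an assumption based on the tl namespace, and the registration itself is not part of this hunk.

    // Hypothetical sketch: emitting a call to the new device-assert intrinsic.
    #include <tvm/ir/op.h>
    #include <tvm/tir/expr.h>
    #include <tvm/tir/stmt.h>

    tvm::tir::Stmt EmitDeviceAssert(const tvm::PrimExpr &cond) {
      using namespace tvm;
      using namespace tvm::tir;
      const Op &op = Op::Get("tl.device_assert");  // assumed registry name
      return Evaluate(Call(DataType::Handle(), op, {cond}));
    }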
src/op/copy.cc

@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
  * @param vmap BufferMap used to resolve RegionOp buffers and ranges.
  */
 Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<CopyNode> node = make_object<CopyNode>();
+  ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
   Array<Range> rgs[2];
   Buffer bf[2];
   for (int i = 0; i < 2; i++) {
@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A TileOperator owning the cloned CopyNode.
  */
 TileOperator CopyNode::Clone() const {
-  auto op = make_object<CopyNode>(*this);
+  auto op = tvm::ffi::make_object<CopyNode>(*this);
   if (par_op_.defined()) {
     op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
   }
@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
   using namespace tvm::transform;
   PassContext pass_ctx = PassContext::Current();
   bool disable_tma_lower =
-      pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
+      pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
   auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
                                T.layout_map, T.analyzer, T.buffer_oob);
@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   using namespace tvm::transform;
   PassContext pass_ctx = PassContext::Current();
   bool disable_tma_lower =
-      pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
+      pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
   auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
                                T.layout_map, analyzer);
   if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
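Both call sites above now read the TMA kill switch through the FFI-boxed Bool instead of a raw C++ bool, passing a boxed default so .value() is always populated. A sketch of the pattern, assuming kDisableTMALower resolves to a string key registered with a Bool config type (the literal key below is illustrative):

    #include <tvm/ir/transform.h>

    bool ReadDisableTMALower() {
      using tvm::Bool;
      using tvm::transform::PassContext;
      PassContext pass_ctx = PassContext::Current();
      // GetConfig<Bool> yields Optional<Bool>; the boxed default makes
      // .value() safe even when the flag was never set.
      return pass_ctx->GetConfig<Bool>("tl.disable_tma_lower", Bool(false))
          .value();
    }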
@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   }
   auto inner_box_dim = as_const_int(desc.smem_box[0]);
-  ICHECK(inner_box_dim != nullptr);
+  if (inner_box_dim == nullptr) {
+    LOG(WARNING) << "inner_box_dim " << desc.smem_box[0]
+                 << " can only be a constant integer for TMA bulk copy, "
+                    "fallback to normal copy";
+    return LowerNormalCopy(T, analyzer);
+  }
   int instruction_dim = *inner_box_dim;
   if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
     instruction_dim = 64 / src->dtype.bytes();
@@ -1722,7 +1727,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
  * @param vmap Mapping from original buffer variables to actual Buffer objects.
  */
 Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>();
+  ObjectPtr<Conv2DIm2ColOpNode> node =
+      tvm::ffi::make_object<Conv2DIm2ColOpNode>();
   node->src = vmap[GetVarFromAccessPtr(args[0])];
   node->dst = vmap[GetVarFromAccessPtr(args[1])];
   node->nhw_step = args[2];
@@ -1747,7 +1753,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
  */
 TileOperator Conv2DIm2ColOpNode::Clone() const {
-  auto op = make_object<Conv2DIm2ColOpNode>(*this);
+  auto op = tvm::ffi::make_object<Conv2DIm2ColOpNode>(*this);
   return Conv2DIm2ColOp(op);
 }
@@ -1973,9 +1979,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   CopyNode::RegisterReflection();
   Conv2DIm2ColOpNode::RegisterReflection();
-});
+}

 } // namespace tl
 } // namespace tvm
src/op/copy.h

@@ -101,8 +101,7 @@ public:
   };
   uint8_t eviction_policy; // Policy for cache eviction

-  static constexpr const char *_type_key = "tl.Copy";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -114,23 +113,6 @@ public:
         .def_ro("coalesced_width", &CopyNode::coalesced_width);
   }

-  bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
-    return equal(src, other->src) && equal(dst, other->dst) &&
-           equal(src_range, other->src_range) &&
-           equal(dst_range, other->dst_range) &&
-           equal(coalesced_width, other->coalesced_width);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(src);
-    hash_reduce(dst);
-    hash_reduce(src_range);
-    hash_reduce(dst_range);
-    hash_reduce(coalesced_width);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   /*!
    * \brief Lower the copy operator to a TIR statement.
    * \param T Arguments for lowering.
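With the hand-written SEqualReduce/SHashReduce and their _type_has_method_* flags gone, structural equality and hashing are expected to be derived from the reflection metadata registered above. A sketch of the complete registration for CopyNode, with the field list inferred from this diff; the ObjectDef builder name is an assumption about the tvm::ffi::reflection API:

    static void RegisterReflection() {
      namespace refl = tvm::ffi::reflection;
      // Each def_ro entry both exposes the field and feeds the derived
      // structural equality/hash, replacing the deleted manual methods.
      refl::ObjectDef<CopyNode>()
          .def_ro("src", &CopyNode::src)
          .def_ro("dst", &CopyNode::dst)
          .def_ro("src_range", &CopyNode::src_range)
          .def_ro("dst_range", &CopyNode::dst_range)
          .def_ro("coalesced_width", &CopyNode::coalesced_width);
    }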
@@ -291,7 +273,7 @@ protected:
 class Copy : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode);

   /*!
    * \brief Constructor.
@@ -323,8 +305,8 @@ public:
   PrimExpr nhw_step; // Step size in NHW dimensions
   PrimExpr c_step;   // Step size in channel dimension

-  static constexpr const char *_type_key = "tl.Conv2DIm2Col";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
+                                    TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -338,26 +320,6 @@ public:
         .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
   }

-  bool SEqualReduce(const Conv2DIm2ColOpNode *other,
-                    SEqualReducer equal) const {
-    return equal(src, other->src) && equal(dst, other->dst) &&
-           equal(stride, other->stride) && equal(padding, other->padding) &&
-           equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
-           equal(eviction_policy, other->eviction_policy);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(src);
-    hash_reduce(dst);
-    hash_reduce(stride);
-    hash_reduce(padding);
-    hash_reduce(dilation);
-    hash_reduce(kernel);
-    hash_reduce(eviction_policy);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   /*!
    * \brief Lower to TIR statement.
    */
@@ -378,7 +340,7 @@ public:
 class Conv2DIm2ColOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
-                                Conv2DIm2ColOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
+                                             Conv2DIm2ColOpNode);
   TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
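The _NULLABLE variant of the ref-methods macro matters here because these wrappers are default-constructed before being bound (for example par_op_ inside CopyNode). A behavior sketch, assuming the macro keeps the nullable default constructor of the old TVM_DEFINE_OBJECT_REF_METHODS and that BufferMap lives in tvm::tl:

    void NullableRefSketch(tvm::Array<tvm::PrimExpr> args,
                           tvm::tl::BufferMap vmap) {
      tvm::tl::Copy copy;        // default-constructed: holds no node
      ICHECK(!copy.defined());   // safe to test before binding
      copy = tvm::tl::Copy(args, vmap);
      ICHECK(copy.defined());
    }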
src/op/fill.cc

@@ -17,6 +17,7 @@
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "builtin.h"
+#include "region.h"

 namespace tvm {
 namespace tl {
@@ -60,9 +61,32 @@ using namespace tir;
  * of bounds.
  */
 Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<FillNode> node = make_object<FillNode>();
-  if (args[0]->IsInstance<BufferLoadNode>()) {
+  ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
+  // Case 1: Region descriptor call (tl.region)
+  if (const auto *call = args[0].as<CallNode>()) {
+    if (call->op.same_as(RegionOp::Get())) {
+      auto region = RegionOp(call->args, vmap);
+      node->dst = region->GetBuffer();
+      node->region = region->GetRanges();
+    } else if (call->op.same_as(builtin::tvm_access_ptr())) {
+      node->dst = vmap[GetVarFromAccessPtr(args[0])];
+      for (int i = 0; i < node->dst->shape.size(); i++) {
+        node->region.push_back(Range(0, node->dst->shape[i]));
+      }
+    } else {
+      ICHECK(false) << "Unsupported call op in tl.fill: "
+                    << Downcast<Op>(call->op)->name;
+    }
+    // Case 2: Explicit BufferRegion (legacy path)
+  } else if (args[0]->IsInstance<BufferRegionNode>()) {
+    auto region = Downcast<BufferRegion>(args[0]);
+    node->dst = region->buffer;
+    node->region = region->region;
+    // Case 3: Vector/scalar region expressed via BufferLoad indices
+  } else if (args[0]->IsInstance<BufferLoadNode>()) {
     auto buffer_load = Downcast<BufferLoad>(args[0]);
     for (const auto &index : buffer_load->indices) {
       if (const auto *ramp = index.as<RampNode>()) {
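Case 3 above turns BufferLoad indices into ranges: a stride-1 Ramp index contributes a range covering its lanes, while a scalar index contributes a unit-extent range. A standalone model of that conversion (plain C++, no TVM types; the names are illustrative):

    #include <cassert>
    #include <utility>
    #include <vector>

    struct Rng { long min, extent; };

    // Model: each index is (base, lanes); lanes == 1 models a scalar index,
    // lanes > 1 models Ramp(base, /*stride=*/1, lanes).
    std::vector<Rng> IndicesToRanges(
        const std::vector<std::pair<long, long>> &indices) {
      std::vector<Rng> ranges;
      for (const auto &[base, lanes] : indices)
        ranges.push_back(Rng{base, lanes});
      return ranges;
    }

    int main() {
      // dst[16, 0:128]: scalar row index, vectorized column index.
      auto r = IndicesToRanges({{16, 1}, {0, 128}});
      assert(r[0].min == 16 && r[0].extent == 1);
      assert(r[1].min == 0 && r[1].extent == 128);
    }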
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
       }
     }
     node->dst = buffer_load->buffer;
+    // Case 4: Access pointer, fill the full buffer
   } else {
     node->dst = vmap[GetVarFromAccessPtr(args[0])];
     for (int i = 0; i < node->dst->shape.size(); i++) {
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
       << " != " << node->dst->shape.size();
   for (int i = 0; i < node->region.size(); i++) {
     // bound check if region is static
-    if (node->region[i]->min.as<IntImm>()) {
-      int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
+    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
+      int64_t min = min_imm->value;
       ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
     }
-    if (node->region[i]->extent.as<IntImm>()) {
-      int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
-      ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
-          << "region[" << i << "] = " << extent << " > "
-          << node->dst->shape[i];
+    if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
+      // Only perform the upper-bound check when the destination shape
+      // extent is also statically known. If the shape is symbolic (e.g., Var),
+      // skip this static check to avoid invalid downcasts.
+      if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
+        ICHECK_LE(extent_imm->value, shape_imm->value)
+            << "region[" << i << "] = " << extent_imm->value << " > "
+            << node->dst->shape[i];
+      }
     }
   }
   data_ = std::move(node);
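The reworked check above only compares extents when both sides are compile-time constants; a symbolic destination shape no longer goes through an invalid Downcast to IntImm. A standalone model of the decision logic:

    #include <cassert>
    #include <optional>

    bool RegionOk(std::optional<long> min, std::optional<long> extent,
                  std::optional<long> shape) {
      if (min && *min < 0) return false;   // static min must be non-negative
      if (extent && shape && *extent > *shape)
        return false;                      // upper bound: static-vs-static only
      return true;                         // anything symbolic: defer to runtime
    }

    int main() {
      assert(RegionOk(0, 64, 128));            // static, in bounds
      assert(!RegionOk(0, 256, 128));          // static, out of bounds
      assert(RegionOk(0, 256, std::nullopt));  // symbolic shape: check skipped
    }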
@@ -117,7 +147,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A TileOperator that owns the copied FillNode.
  */
 TileOperator FillNode::Clone() const {
-  auto op = make_object<FillNode>(*this);
+  auto op = tvm::ffi::make_object<FillNode>(*this);
   return Fill(op);
 }
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   for (int i = 0; i < ndim; i++) {
     Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
     loop_vars.push_back({region[i], var, IterVarType::kDataPar});
-    dst_indices.push_back(var);
+    // Offset the loop induction variable by region min to honor sliced regions
+    dst_indices.push_back(region[i]->min + var);
   }
   Stmt body = BufferStore(dst, value, dst_indices);
   for (int i = ndim - 1; i >= 0; i--) {
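The one-line change above makes sliced fills write to the right rows: the store index is region[i]->min + var instead of the raw induction variable. A standalone model:

    #include <cassert>

    int main() {
      const int region_min = 2, region_extent = 8;  // fill rows [2, 10)
      int first = -1, last = -1;
      for (int i = 0; i < region_extent; ++i) {
        int dst_index = region_min + i;  // previously: dst_index = i
        if (first < 0) first = dst_index;
        last = dst_index;
      }
      assert(first == 2 && last == 9);  // was 0 and 7 before the fix
    }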
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return vectorized_thread_loop;
   } else {
     LOG(FATAL) << "Unsupported scope " << dst.scope();
+    return Stmt();
   }
 }
@@ -226,7 +258,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

 } // namespace tl
 } // namespace tvm
src/op/fill.h

@@ -20,8 +20,7 @@ public:
   tir::Buffer dst;     ///< Destination buffer to fill
   PrimExpr value;      ///< Value to fill with
   Array<Range> region; ///< Region to fill within the buffer

-  static constexpr const char *_type_key = "tl.Fill";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode);

   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
   LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
@@ -35,19 +34,6 @@ public:
         .def_ro("region", &FillNode::region);
   }

-  bool SEqualReduce(const FillNode *other, SEqualReducer equal) const {
-    return equal(dst, other->dst) && equal(value, other->value) &&
-           equal(region, other->region);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(dst);
-    hash_reduce(value);
-    hash_reduce(region);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   TileOperator Clone() const;

 private:
@@ -58,7 +44,7 @@ private:
 /// Wrapper class for fill operations
 class Fill : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
   TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
src/op/finalize_reducer.cc

@@ -33,7 +33,7 @@ using namespace tir;
  * Buffer.
  */
 FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
-  auto node = make_object<FinalizeReducerOpNode>();
+  auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
   node->reducer = vmap[GetVarFromAccessPtr(args[0])];
   node->op = (ReducerOpType)*as_const_int(args[1]);
   data_ = std::move(node);
@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
  * @return TileOperator A TileOperator that contains a deep copy of this node.
  */
 TileOperator FinalizeReducerOpNode::Clone() const {
-  auto node = make_object<FinalizeReducerOpNode>(*this);
+  auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(*this);
   return TileOperator(node);
 }
@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); }

 } // namespace tl
 } // namespace tvm
src/op/finalize_reducer.h

@@ -27,8 +27,8 @@ public:
   tir::Buffer reducer;
   ReducerOpType op;

-  static constexpr const char *_type_key = "tl.FinalizeReducerOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp",
+                                    FinalizeReducerOpNode, TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -37,18 +37,6 @@ public:
         .def_ro("op", &FinalizeReducerOpNode::op);
   }

-  bool SEqualReduce(const FinalizeReducerOpNode *other,
-                    SEqualReducer equal) const {
-    return equal(reducer, other->reducer) && equal(op, other->op);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(reducer);
-    hash_reduce(op);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
                         InferLevel level) const override;
@@ -58,7 +46,7 @@ public:
 class FinalizeReducerOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
-                                FinalizeReducerOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
+                                             FinalizeReducerOpNode);
   TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
src/op/gemm.cc

@@ -12,77 +12,14 @@
 #include <tvm/tir/transform.h>

 #include "../target/utils.h"
+#include "region.h"
+#include "tcgen5_meta.h"

 namespace tvm {
 namespace tl {

 using namespace tir;

-struct TCGEN5MMAMeta {
-  int atom_m, atom_n, atom_k;
-};
-
-// Return {is_success, meta}
-static inline std::pair<bool, TCGEN5MMAMeta>
-GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
-// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
-#define FAIL \
-  return { false, TCGEN5MMAMeta{0, 0, 0} }
-#define SUCCESS(atom_m, atom_n, atom_k) \
-  return { \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
-  }
-  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
-  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
-      (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 16 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 16);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 16);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 16);
-      FAIL;
-    } else {
-      FAIL;
-    }
-  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
-             (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 32 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 32);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 32);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 32);
-      FAIL;
-    } else {
-      FAIL;
-    }
-  }
-  FAIL;
-#undef FAIL
-#undef SUCCESS
-}
-
 /**
  * @brief Construct a Gemm operator from serialized TL arguments and a buffer
  * map.
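The deleted helper (now sourced from tcgen5_meta.h) picks the largest MMA atom that tiles the problem. A standalone mirror of the fp16/bf16-with-fp32-accumulation branch, with a worked case: M=256, N=192, K=32 selects atom_m=128 (M is a multiple of 128) and atom_n=192 (the largest multiple of 16 up to 256 dividing N):

    #include <cassert>

    struct Meta { int atom_m, atom_n, atom_k; };

    bool PickAtoms(int M, int N, int K, Meta *out) {
      if (K % 16 != 0) return false;       // fp16/bf16 path needs K % 16 == 0
      if (M % 128 == 0) {
        for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
          if (N % atom_n == 0) { *out = {128, atom_n, 16}; return true; }
      } else if (M % 64 == 0 || M % 32 == 0) {
        for (int atom_n : {256, 128, 64})  // ws_valid_atom_ns
          if (N % atom_n == 0) {
            *out = {M % 64 == 0 ? 64 : 32, atom_n, 16};
            return true;
          }
      }
      return false;
    }

    int main() {
      Meta m{};
      assert(PickAtoms(256, 192, 32, &m) && m.atom_m == 128 && m.atom_n == 192);
      assert(!PickAtoms(256, 8, 32, &m));  // no atom_n in [16, 256] divides 8
    }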
@@ -111,42 +48,130 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
  * fails with an ICHECK (runtime assertion). No other validation is
  * performed here.
  */
+// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
+// to BufferRegion
+static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
+                                            const BufferMap &vmap) {
+  // Case 1: Already a BufferRegion
+  if (arg->IsInstance<BufferRegionNode>()) {
+    return Downcast<BufferRegion>(arg);
+  }
+  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
+  // extent=1)
+  if (const auto *load = arg.as<BufferLoadNode>()) {
+    Array<Range> ranges;
+    for (const PrimExpr &index : load->indices) {
+      if (const auto *ramp = index.as<RampNode>()) {
+        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
+        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
+            << "Only stride-1 Ramp is supported in GEMM region conversion";
+        ICHECK(ramp->lanes.as<IntImmNode>())
+            << "Scalable vector lanes not supported in GEMM region conversion";
+        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
+      } else {
+        ranges.push_back(Range::FromMinExtent(index, 1));
+      }
+    }
+    return BufferRegion(load->buffer, ranges);
+  }
+  // Case 3: Call nodes
+  if (const auto *call = arg.as<CallNode>()) {
+    // tl.region(...) — reconstruct via RegionOp
+    if (call->op.same_as(RegionOp::Get())) {
+      RegionOp region(call->args, vmap);
+      return BufferRegion(region->GetBuffer(), region->GetRanges());
+    }
+    // builtin.tvm_access_ptr(...) — map var to Buffer and take full region
+    if (call->op.same_as(builtin::tvm_access_ptr())) {
+      Var var = Downcast<Var>(call->args[1]);
+      Buffer buf = vmap[var];
+      Array<Range> ranges;
+      for (PrimExpr extent : buf->shape) {
+        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
+      }
+      return BufferRegion(buf, ranges);
+    }
+  }
+  LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
+  throw; // Unreachable, keeps compiler happy
+}
+
+// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
+// BufferRegion. Offset is computed from all but the last two dimensions;
+// extent is the product of the last two extents. rw_mask: 1=read, 2=write,
+// 3=readwrite.
+static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
+                                        int rw_mask) {
+  Buffer buf = region->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+  ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
+  // Compute row-major strides
+  std::vector<PrimExpr> strides(ndim);
+  PrimExpr one = make_const(buf->shape[0].dtype(), 1);
+  PrimExpr cur = one;
+  for (int i = ndim - 1; i >= 0; --i) {
+    strides[i] = cur;
+    cur = cur * buf->shape[i];
+  }
+  // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
+  PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
+  for (int i = 0; i < ndim - 2; ++i) {
+    offset = offset + region->region[i]->min * strides[i];
+  }
+  // Extent: last two extents product (elements)
+  PrimExpr extent =
+      region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
+  // ptype and return handle
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+}
+
 Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<GemmNode> node = make_object<GemmNode>();
-  node->Aptr = args[0];
-  node->Bptr = args[1];
-  node->Cptr = args[2];
-  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
-  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
-  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
-  node->trans_A = args[3].as<Bool>().value();
-  node->trans_B = args[4].as<Bool>().value();
-  node->M = args[5].as<IntImm>().value()->value;
-  node->N = args[6].as<IntImm>().value()->value;
-  node->K = args[7].as<IntImm>().value()->value;
-  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
-  node->clear_accum = args[9].as<PrimExpr>().value();
-  node->stride_A = args[10].as<IntImm>().value()->value;
-  node->stride_B = args[11].as<IntImm>().value()->value;
-  node->offset_A = args[12].as<IntImm>().value()->value;
-  node->offset_B = args[13].as<IntImm>().value()->value;
+  ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
+  node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
+  node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
+  node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
+  node->a_ = node->aRegion_->buffer;
+  node->b_ = node->bRegion_->buffer;
+  node->c_ = node->cRegion_->buffer;
+  node->transA_ = args[3].as<Bool>().value();
+  node->transB_ = args[4].as<Bool>().value();
+  node->m_ = args[5].as<IntImm>().value()->value;
+  node->n_ = args[6].as<IntImm>().value()->value;
+  node->k_ = args[7].as<IntImm>().value()->value;
+  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
+  node->clearAccum_ = args[9].as<PrimExpr>().value();
+  node->strideA_ = args[10].as<IntImm>().value()->value;
+  node->strideB_ = args[11].as<IntImm>().value()->value;
+  node->offsetA_ = args[12].as<IntImm>().value()->value;
+  node->offsetB_ = args[13].as<IntImm>().value()->value;
   if (args.size() > 14) {
-    node->kPack = args[14].as<IntImm>().value()->value;
-    if (node->kPack != 1 && node->kPack != 2) {
+    node->kPack_ = args[14].as<IntImm>().value()->value;
+    if (node->kPack_ != 1 && node->kPack_ != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 15) {
-    node->wg_wait = args[15].as<IntImm>().value()->value;
+    node->wgWait_ = args[15].as<IntImm>().value()->value;
   }
-  node->mbarptr = args[16];
-  if (node->mbarptr.as<CallNode>()) {
-    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
+  node->mbarPtr_ = args[16];
+  if (node->mbarPtr_.as<CallNode>()) {
+    node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
   } else {
-    node->mbar = std::nullopt;
+    node->mbar_ = std::nullopt;
   }
-  node->C_coords = Array<PrimExpr>(
+  node->cCoords_ = Array<PrimExpr>(
       {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
   data_ = std::move(node);
 }
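MakeAccessPtrFromRegion above flattens everything but the last two dimensions into the pointer offset, and the 2D tile size into the extent. A standalone check of the arithmetic for buffer shape [4, 128, 64] with region [2, 0:128, 0:64]:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<long> shape   = {4, 128, 64};
      std::vector<long> mins    = {2, 0, 0};
      std::vector<long> extents = {1, 128, 64};
      int ndim = static_cast<int>(shape.size());

      // Row-major strides, innermost stride 1.
      std::vector<long> strides(ndim);
      long cur = 1;
      for (int i = ndim - 1; i >= 0; --i) { strides[i] = cur; cur *= shape[i]; }

      // Offset over all but the last two dims; extent is the 2D tile size.
      long offset = 0;
      for (int i = 0; i < ndim - 2; ++i) offset += mins[i] * strides[i];
      long extent = extents[ndim - 2] * extents[ndim - 1];

      assert((strides == std::vector<long>{8192, 64, 1}));
      assert(offset == 2 * 8192 && extent == 128 * 64);
    }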
@@ -160,46 +185,45 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A Gemm operator that owns a copy of this node.
  */
 TileOperator GemmNode::Clone() const {
-  auto op = make_object<GemmNode>(*this);
+  auto op = tvm::ffi::make_object<GemmNode>(*this);
   return Gemm(op);
 }

-bool GemmNode::AllowTCGEN5MMA(Target target) const {
+bool GemmNode::allowTcgen5Mma(Target target) const {
   return TargetIsSm100(target) &&
-         ((A.scope() == "shared.dyn" || A.scope() == "shared" ||
-           A.scope() == "shared.tmem") &&
-          (B.scope() == "shared.dyn" || B.scope() == "shared") &&
-          C.scope() == "shared.tmem") &&
-         GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
+         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
+           a_.scope() == "shared.tmem") &&
+          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
+          c_.scope() == "shared.tmem") &&
+         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
 }

-bool GemmNode::AllowWGMMA(int block_size, Target target) const {
+bool GemmNode::allowWgmma(int block_size, Target target) const {
   tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
   return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
-         TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
-         CheckWGMMA();
+         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
+         checkWgmma();
 }

-GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
-  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
-  bool allow_wgmma = AllowWGMMA(block_size, target);
-  if (allow_tcgen5mma) {
+GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
+  if (allowTcgen5Mma(target)) {
     return GemmInst::kTCGEN5MMA;
-  } else if (allow_wgmma) {
+  } else if (allowWgmma(block_size, target)) {
     return GemmInst::kWGMMA;
   } else if (TargetIsCDNA(target)) {
     return GemmInst::kMFMA;
   } else if (TargetIsCuda(target)) {
     return GemmInst::kMMA;
   } else {
-    ICHECK(0) << "Unsupported target for gemm: " << target->str();
+    ICHECK(0) << "Unsupported target for gemm: " << target;
     return GemmInst::kMMA;
   }
 }

-std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
+std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
     int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
   int num_warps = block_size / TargetGetWarpSize(target);
   if (gemm_inst == GemmInst::kTCGEN5MMA) {
@@ -208,7 +232,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
   int m_warp = 1, n_warp = 1;
   constexpr int kMPerWarp = 16; // Rows processed by a single warp
-  constexpr int kNPerWarp = 8;  // Columns processed by a single warp
+  int kNPerWarp = 8;            // Columns processed by a single warp
+  if (TargetIsVolta(target)) {
+    kNPerWarp = 16;
+  }
   ICHECK(M % kMPerWarp == 0)
       << "M must be divisible by " << kMPerWarp << ", but got " << M;
   ICHECK(N % kNPerWarp == 0)
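The per-warp column granularity is now target-dependent, so the N-divisibility requirement tightens on Volta. A tiny standalone illustration:

    #include <cassert>

    int main() {
      auto n_per_warp = [](bool is_volta) { return is_volta ? 16 : 8; };
      assert(24 % n_per_warp(false) == 0);  // non-Volta: N = 24 is fine
      assert(24 % n_per_warp(true) != 0);   // Volta: would trip the N ICHECK
    }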
@@ -408,51 +435,52 @@
  * @return true if WGMMA is supported for the current buffers, dtypes, and
  * transpose/shape constraints; false otherwise.
  */
-bool GemmNode::CheckWGMMA() const {
-  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
+bool GemmNode::checkWgmma() const {
+  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
     return false;
   }

-  if (C->dtype == DataType::Float(16)) {
-    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
-      return K % 16 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
+  if (c_->dtype == DataType::Float(16)) {
+    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
     else
       return false;
-  } else if (C->dtype == DataType::Float(32)) {
-    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
-      return K % 16 == 0;
-    else if (A->dtype == DataType::BFloat(16) &&
-             B->dtype == DataType::BFloat(16))
-      return K % 16 == 0;
-    else if (A->dtype == DataType::Float(32) &&
-             B->dtype == DataType::Float(32))
-      return (!trans_A) && trans_B && K % 8 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
-  } else if (C->dtype == DataType::Int(32)) {
-    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
+  } else if (c_->dtype == DataType::Float(32)) {
+    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype == DataType::BFloat(16) &&
+             b_->dtype == DataType::BFloat(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype == DataType::Float(32) &&
+             b_->dtype == DataType::Float(32))
+      return (!transA_) && transB_ && k_ % 8 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else
+      return false;
+  } else if (c_->dtype == DataType::Int(32)) {
+    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else
+      return false;
   } else {
@@ -476,8 +504,8 @@ bool GemmNode::CheckWGMMA() const {
  */
 static int GetArchInt(Target target) {
   int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
+  auto s = target->GetAttr<tvm::ffi::String>("arch");
+  ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
     arch_int = std::stoi(arch.substr(3));
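GetArchInt keeps its parsing; only the attribute access moved to tvm::ffi::String with has_value(). A standalone equivalent of the visible parsing branch (strings without the "sm_" prefix keep the zero initializer):

    #include <cassert>
    #include <string>

    int ArchInt(const std::string &arch) {
      int arch_int = 0;
      if (arch.rfind("sm_", 0) == 0)  // true when "sm_" is a prefix
        arch_int = std::stoi(arch.substr(3));
      return arch_int;
    }

    int main() {
      assert(ArchInt("sm_90") == 90);
      assert(ArchInt("sm_100") == 100);
      assert(ArchInt("gfx942") == 0);  // non-"sm_" arch: falls through as 0
    }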
@@ -502,56 +530,61 @@
  */
 Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto block_size = *as_const_int(T.thread_bounds->extent);
-  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
-  auto [warp_m, warp_n] =
-      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
+  GemmInst gemm_inst = getGemmInst(block_size, T.target);
+  auto [warp_m, warp_n] =
+      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
+
+  // Build access pointers from regions locally
+  PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1);
+  PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1);
+  PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3);

   std::stringstream ss;
   std::string op_name;
   if (gemm_inst == GemmInst::kTCGEN5MMA) {
     auto [can_use_tcgen5mma, meta] =
-        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
+        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
     ICHECK(can_use_tcgen5mma);
-    ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
-    ICHECK(C.scope() == "shared.tmem");
-    ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
-    if (A.scope() == "shared.tmem") {
+    ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
+    ICHECK(c_.scope() == "shared.tmem");
+    ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
+    if (a_.scope() == "shared.tmem") {
       op_name = "tl::tcgen5mma_gemm_ts";
-    } else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
+    } else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
       op_name = "tl::tcgen5mma_gemm_ss";
     } else {
       ICHECK(0) << "Unsupported A scope for TCGEN5MMA: "
-                << A.scope(); // If this is triggered, it means Tilelang has bugs.
+                << a_.scope(); // If this is triggered, it means Tilelang has bugs.
     }
-    ICHECK(wg_wait == -1)
+    ICHECK(wgWait_ == -1)
         << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
           "use "
           "wg_wait = -1 and manually synchronize with mbarrier.";
     std::string accum_dtype = "";
-    if (C->dtype.is_float()) {
-      if (C->dtype.bits() == 32) {
+    if (c_->dtype.is_float()) {
+      if (c_->dtype.bits() == 32) {
         accum_dtype = "float";
       }
     }
     ICHECK(!accum_dtype.empty())
-        << "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
-    ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
+        << "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
+    ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
     ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
-    ss << trans_A << ", " << trans_B << ", ";
+    ss << transA_ << ", " << transB_ << ", ";
     ss << accum_dtype;
     ss << ">";
-    auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
+    auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
     Array<PrimExpr> new_args;
     new_args.push_back(StringImm(ss.str()));
     new_args.push_back(Aptr);
     new_args.push_back(Bptr);
-    new_args.push_back(BufferLoad(C_buffer, C_coords));
-    new_args.push_back(mbarptr);
-    new_args.push_back(clear_accum);
+    new_args.push_back(BufferLoad(C_buffer, cCoords_));
+    new_args.push_back(mbarPtr_);
+    new_args.push_back(clearAccum_);
     auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
     // Since TCGEN5MMA atoms provided by CUTLASS always have an internal
@@ -576,47 +609,49 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     }
   }

-  if (A.scope() == "local.fragment") {
-    ICHECK(B.scope() != "local.fragment");
+  if (a_.scope() == "local.fragment") {
+    ICHECK(b_.scope() != "local.fragment");
+    ICHECK(!transA_)
+        << "gemm_rs requires the A operand to be in non-transposed layout.";
     op_name = "tl::gemm_rs";
-  } else if (B.scope() == "local.fragment") {
+  } else if (b_.scope() == "local.fragment") {
     op_name = "tl::gemm_sr";
   } else {
     op_name = "tl::gemm_ss";
   }
-  ICHECK(C.scope() == "local.fragment");
-  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
+  ICHECK(c_.scope() == "local.fragment");
+  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
   ss << warp_m << ", " << warp_n << ", ";
-  ss << trans_A << ", " << trans_B;
-  auto clear_accum_bool = clear_accum.as<Bool>();
+  ss << transA_ << ", " << transB_;
+  auto clear_accum_bool = clearAccum_.as<Bool>();
   ICHECK(clear_accum_bool.has_value())
-      << "clear_accum must be a constant Bool type, got " << clear_accum;
+      << "clear_accum must be a constant Bool type, got " << clearAccum_;
   ss << ", " << bool(clear_accum_bool.value());
   if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
-    ss << ", " << stride_A << ", " << stride_B;
-    ss << ", " << offset_A << ", " << offset_B;
+    ss << ", " << strideA_ << ", " << strideB_;
+    ss << ", " << offsetA_ << ", " << offsetB_;
   }
   if (TargetIsCDNA(T.target)) {
     // for cdna gemm, we need to specify kPack
-    ss << ", " << kPack;
+    ss << ", " << kPack_;
   } else if (TargetIsHopper(T.target)) {
     ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
   }
   // Emit wg_wait if necessary
   if (TargetIsHopper(T.target)) {
-    if (wg_wait != 0) {
-      ss << ", " << wg_wait;
+    if (wgWait_ != 0) {
+      ss << ", " << wgWait_;
     }
   } else if (TargetIsSm100(T.target)) {
     // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
     // but all threads need to wait, so we emit another statement for cases
     // where wg_wait == 0.
-    ICHECK(wg_wait == 0 || wg_wait == -1)
+    ICHECK(wgWait_ == 0 || wgWait_ == -1)
         << "wg_wait must be 0 or -1 for Sm100";
   } else {
-    ICHECK(wg_wait == 0)
+    ICHECK(wgWait_ == 0)
         << "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
   }
   ss << ">";
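The renamed members feed the same extern-call template string. A standalone sketch of the assembly for a 128x128x32 tile with a 2x2 warp grid, no transposes, and no accumulator clear (the target-gated stride/offset/wg_wait arguments are elided here):

    #include <cassert>
    #include <sstream>
    #include <string>

    int main() {
      int m = 128, n = 128, k = 32, warp_m = 2, warp_n = 2;
      bool trans_a = false, trans_b = false, clear_accum = false;
      std::stringstream ss;
      ss << "tl::gemm_ss" << "<" << m << ", " << n << ", " << k << ", ";
      ss << warp_m << ", " << warp_n << ", ";
      ss << trans_a << ", " << trans_b;  // bools stream as 0/1, as in the code above
      ss << ", " << clear_accum;
      ss << ">";
      assert(ss.str() == "tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>");
    }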
@@ -652,151 +687,152 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
   LayoutMap results;
   auto thread_range = T.thread_bounds;
   auto block_size = *as_const_int(thread_range->extent);
-  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
-  auto [warp_m, warp_n] =
-      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
+  GemmInst gemm_inst = getGemmInst(block_size, T.target);
+  auto [warp_m, warp_n] =
+      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

   if (TargetIsVolta(T.target)) {
-    ICHECK(C.scope() == "local.fragment")
+    ICHECK(c_.scope() == "local.fragment")
         << "Volta gemm only supports C in local.fragment scope, got "
-        << C.scope();
+        << c_.scope();
     auto fragment =
-        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      results.Set(A,
-                  makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
-                                        *as_const_int(A->shape[dim_A - 1]),
-                                        true, !trans_A));
-    } else if (A.scope() == "local.fragment") {
-      ICHECK(trans_A == false);
-      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
-      results.Set(A, fragment->BindThreadRange(thread_range));
+        makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                               c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      results.Set(a_,
+                  makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
+                                        *as_const_int(a_->shape[dim_A - 1]),
+                                        true, !transA_));
+    } else if (a_.scope() == "local.fragment") {
+      ICHECK(transA_ == false);
+      auto fragment =
+          makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
+      results.Set(a_, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }

-    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
-    int dim_B = B->shape.size();
-    results.Set(B,
-                makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
-                                      *as_const_int(B->shape[dim_B - 1]),
-                                      false, trans_B));
+    ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
+    int dim_B = b_->shape.size();
+    results.Set(b_,
+                makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
+                                      *as_const_int(b_->shape[dim_B - 1]),
+                                      false, transB_));
   } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
              TargetIsSM120(T.target) ||
              (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
-    ICHECK(C.scope() == "local.fragment")
-        << "MMA only supports C in local.fragment scope, got " << C.scope();
+    ICHECK(c_.scope() == "local.fragment")
+        << "MMA only supports C in local.fragment scope, got " << c_.scope();
     auto fragment =
-        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A,
-                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
-                                   A->dtype.bits(), !trans_A));
-    } else if (A.scope() == "local.fragment") {
-      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
-                                        A->dtype.bits(), trans_A);
-      results.Set(A, fragment->BindThreadRange(thread_range));
+        makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_,
+                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
+                                   a_->dtype.bits(), !transA_));
+    } else if (a_.scope() == "local.fragment") {
+      auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
+                                        a_->dtype.bits(), transA_);
+      results.Set(a_, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
-      results.Set(B,
-                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
-                                   B->dtype.bits(), trans_B));
-    } else if (B.scope() == "local.fragment") {
-      auto fragment =
-          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-      results.Set(B, fragment->BindThreadRange(thread_range));
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
+      results.Set(b_,
+                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
+                                   b_->dtype.bits(), transB_));
+    } else if (b_.scope() == "local.fragment") {
+      auto fragment =
+          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
+      results.Set(b_, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
   } else if (TargetIsHopper(T.target)) {
-    ICHECK(C.scope() == "local.fragment")
+    ICHECK(c_.scope() == "local.fragment")
         << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
-        << "only supports C in local.fragment scope, got " << C.scope();
+        << "only supports C in local.fragment scope, got " << c_.scope();
     auto fragment =
         gemm_inst == GemmInst::kWGMMA
-            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
-                                      C->dtype.bits())
-            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
+            ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m, n_ / warp_n,
+                                      c_->dtype.bits())
+            : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                                c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
       const int64_t continuity =
-          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
+          transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
       auto ABLayout =
           gemm_inst == GemmInst::kWGMMA
               ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
-                                       A->dtype.bits(), !trans_A)
+                                       a_->dtype.bits(), !transA_)
               : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
-                                 A->dtype.bits(), !trans_A);
-      results.Set(A, ABLayout);
+                                 a_->dtype.bits(), !transA_);
+      results.Set(a_, ABLayout);
     } else {
       auto fragment =
-          makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits(),
-                            trans_A);
-      results.Set(A, fragment->BindThreadRange(thread_range));
+          makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
+                            a_->dtype.bits(), transA_);
+      results.Set(a_, fragment->BindThreadRange(thread_range));
     }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
       const int64_t continuity =
-          trans_B ? mat_continuous : mat_continuous / warp_n;
+          transB_ ? mat_continuous : mat_continuous / warp_n;
       auto ABLayout =
           gemm_inst == GemmInst::kWGMMA
               ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
-                                       B->dtype.bits(), trans_B)
+                                       b_->dtype.bits(), transB_)
               : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
-                                 B->dtype.bits(), trans_B);
-      results.Set(B, ABLayout);
+                                 b_->dtype.bits(), transB_);
+      results.Set(b_, ABLayout);
     } else {
       auto fragment =
-          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-      results.Set(B, fragment->BindThreadRange(thread_range));
+          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
+      results.Set(b_, fragment->BindThreadRange(thread_range));
     }
   } else if (gemm_inst == GemmInst::kTCGEN5MMA) {
-    ICHECK(C.scope() == "shared.tmem")
-        << "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
-    ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
+    ICHECK(c_.scope() == "shared.tmem")
+        << "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
+    ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
         << "Current TCGEN5MMA only supports A in shared.dyn scope";
     auto [can_use_tcgen5mma, meta] =
-        GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
+        GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
     ICHECK(can_use_tcgen5mma);
     {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                           mat_continuous, A->dtype.bits(),
-                                           trans_A ? 1 : 2));
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
                                            mat_continuous, a_->dtype.bits(),
+                                            transA_ ? 1 : 2));
     }
     {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
       const int64_t continuity = mat_continuous;
-      results.Set(B, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                           continuity, B->dtype.bits(),
-                                           trans_B ? 2 : 1));
+      results.Set(b_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
+                                            continuity, b_->dtype.bits(),
+                                            transB_ ? 2 : 1));
     }
     {
       Layout res;
-      IterVar i = make_itervar("i", M);
-      IterVar j = make_itervar("j", N);
-      ICHECK(M % meta.atom_m == 0);
+      IterVar i = make_itervar("i", m_);
+      IterVar j = make_itervar("j", n_);
+      ICHECK(m_ % meta.atom_m == 0);
       PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
-                          FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
+                          FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
       PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
       PrimExpr aj = FloorMod(j, meta.atom_n);
       if (meta.atom_m == 128) {
@@ -822,46 +858,46 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
      } else {
        ICHECK(0);
      }
-      results.Set(C, res);
+      results.Set(c_, res);
    }
  } else if (TargetIsCDNA(T.target)) {
-    ICHECK(C.scope() == "local.fragment")
+    ICHECK(c_.scope() == "local.fragment")
        << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
-        << C.scope();
+        << c_.scope();
    if (TargetIsDCU(T.target)) {
      auto fragment =
-          makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-      results.Set(C, fragment->BindThreadRange(thread_range));
+          makeGemmFragmentCDCU(m_, n_, m_ / warp_m, n_ / warp_n,
+                               c_->dtype.bits());
+      results.Set(c_, fragment->BindThreadRange(thread_range));
    } else {
      auto fragment =
-          makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-      results.Set(C, fragment->BindThreadRange(thread_range));
+          makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n,
+                                c_->dtype.bits());
+      results.Set(c_, fragment->BindThreadRange(thread_range));
    }
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
-          *as_const_int(A->shape[dim_A - 2]), *as_const_int(A->shape[dim_A - 1]),
-          A->dtype.bits(), kPack);
-      results.Set(A, shared_layout);
-    } else if (A.scope() == "local.fragment") {
-      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
-                                            A->dtype.bits(), kPack, trans_A);
-      results.Set(A, fragment->BindThreadRange(thread_range));
+          *as_const_int(a_->shape[dim_A - 2]),
+          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
+      results.Set(a_, shared_layout);
+    } else if (a_.scope() == "local.fragment") {
+      auto fragment = makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m,
+                                            n_ / warp_n, a_->dtype.bits(),
+                                            kPack_, transA_);
+      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
-          *as_const_int(B->shape[dim_B - 2]), *as_const_int(B->shape[dim_B - 1]),
-          B->dtype.bits(), kPack);
-      results.Set(B, shared_layout);
-    } else if (B.scope() == "local.fragment") {
+          *as_const_int(b_->shape[dim_B - 2]),
+          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
+      results.Set(b_, shared_layout);
+    } else if (b_.scope() == "local.fragment") {
      auto fragment =
-          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-      results.Set(B, fragment->BindThreadRange(thread_range));
+          makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
+      results.Set(b_, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
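Throughout this branch the fragment shapes are derived from the warp partition: each warp owns an m_ / warp_m by n_ / warp_n sub-tile of C. A standalone sketch of that arithmetic, with an assumed 2x2 split (the real values come from computeWarpPartition):

    #include <cassert>
    #include <cstdio>

    int main() {
      // Assumed example: a 128x128 block-level GEMM handled by 4 warps,
      // split 2 warps along M and 2 along N.
      const int m = 128, n = 128;
      const int warp_m = 2, warp_n = 2;
      assert(m % warp_m == 0 && n % warp_n == 0);
      // These quotients are what the makeGemmFragment* helpers receive.
      std::printf("per-warp C fragment: %d x %d\n", m / warp_m, n / warp_n);
      return 0;
    }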
@@ -880,18 +916,17 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
 TVM_REGISTER_OP("tl.GemmWarpPolicy")
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   GemmNode::RegisterReflection();
   GemmWarpPolicyNode::RegisterReflection();
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
       "tl.GemmWarpPolicyComputeWarpPartition",
       [](GemmWarpPolicy policy, int M, int N, int block_size, Target target,
          GemmInst gemm_inst) {
-        policy->ComputeWarpPartition(M, N, block_size, target, gemm_inst);
+        policy->computeWarpPartition(M, N, block_size, target, gemm_inst);
         return;
       });
-});
+}
 } // namespace tl
 } // namespace tvm
src/op/gemm.h View file @ bbbf4207
@@ -30,8 +30,7 @@ public:
   mutable int n_warp{0};
   int policy_type;

-  static constexpr const char *_type_key = "tl.GemmWarpPolicy";
-  TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -41,22 +40,7 @@ public:
         .def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
   }

-  bool SEqualReduce(const GemmWarpPolicyNode *other,
-                    SEqualReducer equal) const {
-    return equal(policy_type, other->policy_type) &&
-           equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(policy_type);
-    hash_reduce(m_warp);
-    hash_reduce(n_warp);
-  }
-
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
-  std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
+  std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
                                            Target target,
                                            GemmInst gemm_inst) const;
@@ -74,22 +58,23 @@ public:
 class GemmWarpPolicy : public ObjectRef {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef,
+                                             GemmWarpPolicyNode);

   explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) {
-    auto node = make_object<GemmWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
     node->policy_type = (int)policy_type;
     data_ = std::move(node);
   }

   explicit GemmWarpPolicy(int policy_type) {
-    auto node = make_object<GemmWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
     node->policy_type = policy_type;
     data_ = std::move(node);
   }

   explicit GemmWarpPolicy(int m_warp, int n_warp) {
-    auto node = make_object<GemmWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
     node->m_warp = m_warp;
     node->n_warp = n_warp;
     node->policy_type = (int)GemmWarpPolicyType::kFree;
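For orientation, the three constructors above cover the ways a policy is produced; a usage fragment (not standalone: it assumes this header, and kSquare is an assumed enumerator name, since only kFree appears in this hunk):

    // Sketch only; requires gemm.h from this repo. GemmWarpPolicyType::kSquare
    // is an assumed enumerator name -- only kFree is visible in this hunk.
    GemmWarpPolicy from_enum(GemmWarpPolicyType::kSquare); // named policy
    GemmWarpPolicy from_int(0);  // deserialized policy_type value
    GemmWarpPolicy fixed(2, 2);  // explicit 2x2 split, stored as kFree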
@@ -99,89 +84,48 @@ public:
 class GemmNode : public TileOperatorNode {
 public:
-  bool CheckWGMMA() const;
-  tir::Buffer A, B, C;
-  // pointer to the A, B, C
-  PrimExpr Aptr, Bptr, Cptr;
-  bool trans_A, trans_B;
-  int M, N, K;
-  int stride_A, stride_B;
-  int offset_A, offset_B;
-  PrimExpr clear_accum = const_false();
+  bool checkWgmma() const;
+  tir::Buffer a_, b_, c_;
+  // BufferRegion for A, B and C
+  BufferRegion aRegion_, bRegion_, cRegion_;
+  bool transA_, transB_;
+  int m_, n_, k_;
+  int strideA_, strideB_;
+  int offsetA_, offsetB_;
+  PrimExpr clearAccum_ = const_false();
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
-  int kPack = 1;
-  int wg_wait = 0;
-  PrimExpr mbarptr;
-  std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
-  Array<PrimExpr> C_coords;
-  mutable GemmWarpPolicy policy;
-
-  static constexpr const char *_type_key = "tl.Gemm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
+  int kPack_ = 1;
+  int wgWait_ = 0;
+  PrimExpr mbarPtr_;
+  std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
+  Array<PrimExpr> cCoords_;
+  mutable GemmWarpPolicy policy_;
+
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<GemmNode>()
-        .def_ro("A", &GemmNode::A)
-        .def_ro("B", &GemmNode::B)
-        .def_ro("C", &GemmNode::C)
-        .def_ro("Aptr", &GemmNode::Aptr)
-        .def_ro("Bptr", &GemmNode::Bptr)
-        .def_ro("Cptr", &GemmNode::Cptr)
-        .def_ro("trans_A", &GemmNode::trans_A)
-        .def_ro("trans_B", &GemmNode::trans_B)
-        .def_ro("M", &GemmNode::M)
-        .def_ro("N", &GemmNode::N)
-        .def_ro("K", &GemmNode::K)
-        .def_ro("stride_A", &GemmNode::stride_A)
-        .def_ro("stride_B", &GemmNode::stride_B)
-        .def_ro("offset_A", &GemmNode::offset_A)
-        .def_ro("offset_B", &GemmNode::offset_B)
-        .def_ro("clear_accum", &GemmNode::clear_accum)
-        .def_ro("kPack", &GemmNode::kPack)
-        .def_ro("wg_wait", &GemmNode::wg_wait)
-        .def_ro("policy", &GemmNode::policy);
-  }
-
-  bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
-    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-           equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
-           equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
-           equal(trans_B, other->trans_B) && equal(M, other->M) &&
-           equal(N, other->N) && equal(K, other->K) &&
-           equal(stride_A, other->stride_A) &&
-           equal(stride_B, other->stride_B) &&
-           equal(offset_A, other->offset_A) &&
-           equal(offset_B, other->offset_B) &&
-           equal(clear_accum, other->clear_accum) &&
-           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
-           equal(policy, other->policy);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(A);
-    hash_reduce(B);
-    hash_reduce(C);
-    hash_reduce(Aptr);
-    hash_reduce(Bptr);
-    hash_reduce(Cptr);
-    hash_reduce(trans_A);
-    hash_reduce(trans_B);
-    hash_reduce(M);
-    hash_reduce(N);
-    hash_reduce(K);
-    hash_reduce(stride_A);
-    hash_reduce(stride_B);
-    hash_reduce(offset_A);
-    hash_reduce(offset_B);
-    hash_reduce(clear_accum);
-    hash_reduce(kPack);
-    hash_reduce(wg_wait);
-    hash_reduce(policy);
-  }
-
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
+        .def_ro("a", &GemmNode::a_)
+        .def_ro("b", &GemmNode::b_)
+        .def_ro("c", &GemmNode::c_)
+        .def_ro("aRegion", &GemmNode::aRegion_)
+        .def_ro("bRegion", &GemmNode::bRegion_)
+        .def_ro("cRegion", &GemmNode::cRegion_)
+        .def_ro("transA", &GemmNode::transA_)
+        .def_ro("transB", &GemmNode::transB_)
+        .def_ro("m", &GemmNode::m_)
+        .def_ro("n", &GemmNode::n_)
+        .def_ro("k", &GemmNode::k_)
+        .def_ro("strideA", &GemmNode::strideA_)
+        .def_ro("strideB", &GemmNode::strideB_)
+        .def_ro("offsetA", &GemmNode::offsetA_)
+        .def_ro("offsetB", &GemmNode::offsetB_)
+        .def_ro("clearAccum", &GemmNode::clearAccum_)
+        .def_ro("kPack", &GemmNode::kPack_)
+        .def_ro("wgWait", &GemmNode::wgWait_)
+        .def_ro("policy", &GemmNode::policy_);
+  }

   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
@@ -190,16 +134,16 @@ public:
   TileOperator Clone() const;

 private:
-  GemmInst GetGemmInst(int block_size, Target target) const;
-  bool AllowTCGEN5MMA(Target target) const;
-  bool AllowWGMMA(int block_size, Target target) const;
+  GemmInst getGemmInst(int block_size, Target target) const;
+  bool allowTcgen5Mma(Target target) const;
+  bool allowWgmma(int block_size, Target target) const;

   mutable bool completed_ = false;
 };

 class Gemm : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
   TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
src/op/gemm_py.cc View file @ bbbf4207
@@ -12,13 +12,101 @@
 #include <tvm/tir/transform.h>

 #include "../target/utils.h"
+#include "tvm/ffi/string.h"
+#include "region.h"
+#include "tcgen5_meta.h"

 namespace tvm {
 namespace tl {

 using namespace tir;

+// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
+// to BufferRegion
+static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
+                                            const BufferMap &vmap) {
+  // Case 1: Already a BufferRegion
+  if (arg->IsInstance<BufferRegionNode>()) {
+    return Downcast<BufferRegion>(arg);
+  }
+  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
+  // extent=1)
+  if (const auto *load = arg.as<BufferLoadNode>()) {
+    Array<Range> ranges;
+    for (const PrimExpr &index : load->indices) {
+      if (const auto *ramp = index.as<RampNode>()) {
+        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
+        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
+            << "Only stride-1 Ramp is supported in GEMM region conversion";
+        ICHECK(ramp->lanes.as<IntImmNode>())
+            << "Scalable vector lanes not supported in GEMM region conversion";
+        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
+      } else {
+        ranges.push_back(Range::FromMinExtent(index, 1));
+      }
+    }
+    return BufferRegion(load->buffer, ranges);
+  }
+  // Case 3: Call nodes
+  if (const auto *call = arg.as<CallNode>()) {
+    // tl.region(...) — reconstruct via RegionOp
+    if (call->op.same_as(RegionOp::Get())) {
+      RegionOp region(call->args, vmap);
+      return BufferRegion(region->GetBuffer(), region->GetRanges());
+    }
+    // builtin.tvm_access_ptr(...) — map var to Buffer and take full region
+    if (call->op.same_as(builtin::tvm_access_ptr())) {
+      Var var = Downcast<Var>(call->args[1]);
+      Buffer buf = vmap.at(var);
+      Array<Range> ranges;
+      for (PrimExpr extent : buf->shape) {
+        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
+      }
+      return BufferRegion(buf, ranges);
+    }
+  }
+  LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
+  throw; // Unreachable, keeps compiler happy
+}
+
+// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
+// BufferRegion. Offset is computed from all but the last two dimensions; extent
+// is the product of the last two extents. rw_mask: 1=read, 2=write,
+// 3=readwrite.
+static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
+                                        int rw_mask) {
+  Buffer buf = region->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+  ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
+  // Compute row-major strides
+  std::vector<PrimExpr> strides(ndim);
+  PrimExpr one = make_const(buf->shape[0].dtype(), 1);
+  PrimExpr cur = one;
+  for (int i = ndim - 1; i >= 0; --i) {
+    strides[i] = cur;
+    cur = cur * buf->shape[i];
+  }
+  // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
+  PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
+  for (int i = 0; i < ndim - 2; ++i) {
+    offset = offset + region->region[i]->min * strides[i];
+  }
+  // Extent: last two extents product (elements)
+  PrimExpr extent =
+      region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
+  // ptype and return handle
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+}
+
 /**
  * @brief Construct a Gemm operator from serialized TL arguments and a buffer
  * map.
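To make MakeAccessPtrFromRegion's offset/extent computation concrete: for a 3-D buffer where the region picks one batch slice and a full 2-D tile, only the leading (batch) dimension contributes to the offset. A standalone sketch of the same row-major stride math with plain integers (the shape and region are made-up values):

    #include <cstdio>

    int main() {
      // Made-up example: buffer shape [4][128][64], region [2:3, 0:128, 0:64].
      const int ndim = 3;
      const int shape[ndim] = {4, 128, 64};
      const int mins[ndim] = {2, 0, 0};
      const int extents[ndim] = {1, 128, 64};
      // Row-major strides, innermost = 1 (mirrors MakeAccessPtrFromRegion).
      int strides[ndim];
      int cur = 1;
      for (int i = ndim - 1; i >= 0; --i) {
        strides[i] = cur;
        cur *= shape[i];
      }
      // Offset from all but the last two dims; extent is the 2-D tile size.
      int offset = 0;
      for (int i = 0; i < ndim - 2; ++i)
        offset += mins[i] * strides[i];
      int extent = extents[ndim - 2] * extents[ndim - 1];
      // Prints: offset=16384 elements, extent=8192 elements
      std::printf("offset=%d elements, extent=%d elements\n", offset, extent);
      return 0;
    }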
@@ -48,34 +136,43 @@ using namespace tir;
  * performed here.
  */
 GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
-  node->Aptr = args[0];
-  node->Bptr = args[1];
-  node->Cptr = args[2];
-  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
-  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
-  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
-  node->trans_A = args[3].as<Bool>().value();
-  node->trans_B = args[4].as<Bool>().value();
-  node->M = args[5].as<IntImm>().value()->value;
-  node->N = args[6].as<IntImm>().value()->value;
-  node->K = args[7].as<IntImm>().value()->value;
-  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
-  node->clear_accum = args[9].as<PrimExpr>().value();
-  node->stride_A = args[10].as<IntImm>().value()->value;
-  node->stride_B = args[11].as<IntImm>().value()->value;
-  node->offset_A = args[12].as<IntImm>().value()->value;
-  node->offset_B = args[13].as<IntImm>().value()->value;
+  ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
+  node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
+  node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
+  node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
+  node->a_ = node->aRegion_->buffer;
+  node->b_ = node->bRegion_->buffer;
+  node->c_ = node->cRegion_->buffer;
+  node->transA_ = args[3].as<Bool>().value();
+  node->transB_ = args[4].as<Bool>().value();
+  node->m_ = args[5].as<IntImm>().value()->value;
+  node->n_ = args[6].as<IntImm>().value()->value;
+  node->k_ = args[7].as<IntImm>().value()->value;
+  node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
+  node->clearAccum_ = args[9].as<PrimExpr>().value();
+  node->strideA_ = args[10].as<IntImm>().value()->value;
+  node->strideB_ = args[11].as<IntImm>().value()->value;
+  node->offsetA_ = args[12].as<IntImm>().value()->value;
+  node->offsetB_ = args[13].as<IntImm>().value()->value;
   if (args.size() > 14) {
-    node->kPack = args[14].as<IntImm>().value()->value;
-    if (node->kPack != 1 && node->kPack != 2) {
+    node->kPack_ = args[14].as<IntImm>().value()->value;
+    if (node->kPack_ != 1 && node->kPack_ != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 15) {
-    node->wg_wait = args[15].as<IntImm>().value()->value;
+    node->wgWait_ = args[15].as<IntImm>().value()->value;
   }
+  node->mbarPtr_ = args[16];
+  if (node->mbarPtr_.as<CallNode>()) {
+    node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
+  } else {
+    node->mbar_ = std::nullopt;
+  }
+  node->cCoords_ = Array<PrimExpr>(
+      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
   data_ = std::move(node);
 }
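Since the constructor fixes the serialized argument layout, it is worth summarizing in one place (taken directly from the indices consumed above):

    // Serialized GemmPy argument layout, as consumed by the constructor:
    //   args[0..2]   A/B/C regions (BufferRegion, BufferLoad,
    //                tvm_access_ptr, or tl.region)
    //   args[3..4]   transA_, transB_    (Bool)
    //   args[5..7]   m_, n_, k_          (IntImm)
    //   args[8]      warp policy         (IntImm -> GemmWarpPolicy)
    //   args[9]      clearAccum_         (PrimExpr)
    //   args[10..11] strideA_, strideB_  (IntImm)
    //   args[12..13] offsetA_, offsetB_  (IntImm)
    //   args[14]     kPack_              (optional; must be 1 or 2)
    //   args[15]     wgWait_             (optional)
    //   args[16]     mbarPtr_            (access ptr Call, else mbar_ unset)
    //   args[17..18] cCoords_            (PrimExpr pair)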
@@ -88,20 +185,41 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A Gemm operator that owns a copy of this node.
  */
 TileOperator GemmPyNode::Clone() const {
-  auto op = make_object<GemmPyNode>(*this);
+  auto op = tvm::ffi::make_object<GemmPyNode>(*this);
   return GemmPy(op);
 }

-GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
+bool GemmPyNode::allowTcgen5Mma(Target target) const {
+  return TargetIsSm100(target) &&
+         ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
+           a_.scope() == "shared.tmem") &&
+          (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
+          c_.scope() == "shared.tmem") &&
+         GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
+}
+
+bool GemmPyNode::allowWgmma(int block_size, Target target) const {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
-  if (allow_wgmma) {
+  return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
+         TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
+         checkWgmma();
+}
+
+GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
+  bool allow_tcgen5mma = allowTcgen5Mma(target);
+  bool allow_wgmma = allowWgmma(block_size, target);
+  if (allow_tcgen5mma) {
+    return GemmInst::kTCGEN5MMA;
+  } else if (allow_wgmma) {
     return GemmInst::kWGMMA;
   } else if (TargetIsCDNA(target)) {
     return GemmInst::kMFMA;
-  } else if (TargetIsCuda(target)) {
+  } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
+             TargetIsTuring(target) || TargetIsHopper(target) ||
+             TargetIsSm100(target)) {
     return GemmInst::kMMA;
   } else {
     ICHECK(0) << "Unsupported target for gemm: " << target->str();
@@ -140,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
  * @return true if WGMMA is supported for the current buffers, dtypes, and
  * transpose/shape constraints; false otherwise.
  */
-bool GemmPyNode::CheckWGMMA() const {
-  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
+bool GemmPyNode::checkWgmma() const {
+  if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
     return false;
   }

-  if (C->dtype == DataType::Float(16)) {
-    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
-      return K % 16 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
+  if (c_->dtype == DataType::Float(16)) {
+    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
     else
       return false;
-  } else if (C->dtype == DataType::Float(32)) {
-    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
-      return K % 16 == 0;
-    else if (A->dtype == DataType::BFloat(16) &&
-             B->dtype == DataType::BFloat(16))
-      return K % 16 == 0;
-    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
-      return (!trans_A) && trans_B && K % 8 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
-      return (!trans_A) && trans_B && K % 32 == 0;
+  } else if (c_->dtype == DataType::Float(32)) {
+    if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype == DataType::BFloat(16) &&
+             b_->dtype == DataType::BFloat(16))
+      return k_ % 16 == 0;
+    else if (a_->dtype == DataType::Float(32) &&
+             b_->dtype == DataType::Float(32))
+      return (!transA_) && transB_ && k_ % 8 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
+      return (!transA_) && transB_ && k_ % 32 == 0;
     else
       return false;
-  } else if (C->dtype == DataType::Int(32)) {
-    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
-    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
-      return (!trans_A) && trans_B && K % 32 == 0;
+  } else if (c_->dtype == DataType::Int(32)) {
+    if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
+    else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
+      return (!transA_) && transB_ && k_ % 32 == 0;
     else
       return false;
   } else {
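Spelled out, the branches above gate WGMMA on operand dtypes, transpose flags, and K divisibility; a few concrete instances (derived directly from the conditions above):

    // C=fp16:  A=B=fp16              -> k_ % 16 == 0 (any transA_/transB_)
    // C=fp32:  fp16 or bf16 inputs   -> k_ % 16 == 0 (any transA_/transB_)
    // C=fp32:  fp32 inputs           -> !transA_ && transB_ && k_ % 8  == 0
    // C=fp16/fp32: any e4m3/e5m2 mix -> !transA_ && transB_ && k_ % 32 == 0
    // C=int32: any int8/uint8 mix    -> !transA_ && transB_ && k_ % 32 == 0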
@@ -208,8 +327,8 @@ bool GemmPyNode::CheckWGMMA() const {
  */
 static int GetArchInt(Target target) {
   int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
+  auto s = target->GetAttr<tvm::ffi::String>("arch");
+  ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
     arch_int = std::stoi(arch.substr(3));
@@ -221,18 +340,19 @@ static int GetArchInt(Target target) {
 Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto block_size = *as_const_int(T.thread_bounds->extent);
-  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
+  GemmInst gemm_inst = getGemmInst(block_size, T.target);
   auto [warp_m, warp_n] =
-      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
+      policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

   if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
     auto prim_func =
-        Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
-                                T.thread_bounds, T.thread_var));
+        Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map,
+                                T.target, T.thread_bounds, T.thread_var));
     ICHECK(prim_func->attrs.defined());
-    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
-    ICHECK(global_symbol.defined());
+    auto global_symbol =
+        prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
+    ICHECK(global_symbol.has_value());
     if (prim_func->body.as<BlockRealizeNode>()) {
       BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
       auto block = block_realize->block;
@@ -265,7 +385,15 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
   if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
     results = Downcast<LayoutMap>(
-        (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
+        (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
+    // Bind all fragment layouts with the provided thread range
+    for (auto kv : results) {
+      const Buffer &buf = kv.first;
+      const Layout &layout = kv.second;
+      if (auto frag = layout.as<Fragment>()) {
+        results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
+      }
+    }
   } else {
     LOG(FATAL) << "No infer layout function found for gemm_py";
   }
@@ -279,15 +407,41 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.GemmPyGemmInst",
                         [](GemmPy gemm_py, int block_size, Target target) {
-                          return gemm_py->GetGemmInst(block_size, target);
+                          return gemm_py->getGemmInst(block_size, target);
                         });
-});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def(
+      "tl.get_tcgen5_mma_meta",
+      [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
+        auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
+        Array<Integer> result;
+        if (success) {
+          result.push_back(Integer(meta.atom_m));
+          result.push_back(Integer(meta.atom_n));
+          result.push_back(Integer(meta.atom_k));
+        }
+        return result;
+      });
+  refl::GlobalDef().def(
+      "tl.get_tcgen5_instr_desc",
+      [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
+         DataType c_dtype, bool a_is_k_major, bool b_is_k_major,
+         int scale_in_a, int scale_in_b) {
+        uint32_t desc =
+            GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, c_dtype,
+                               a_is_k_major, b_is_k_major, scale_in_a,
+                               scale_in_b);
+        return Integer(static_cast<int64_t>(desc));
+      });
+}

 } // namespace tl
 } // namespace tvm
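Once registered, these globals are reachable through the FFI function table; a minimal C++ sketch mirroring how this file itself looks up "tl.gemm_py.lower" (the M/N/K values and dtypes are illustrative):

    // A minimal sketch, assuming the same TVM FFI headers this file uses.
    if (const auto f = ffi::Function::GetGlobal("tl.get_tcgen5_mma_meta")) {
      Array<Integer> meta = Downcast<Array<Integer>>(
          (*f)(128, 256, 64, DataType::BFloat(16), DataType::Float(32)));
      // An empty array means no TCGEN5MMA atom fits this shape/dtype combo.
      if (!meta.empty()) {
        LOG(INFO) << "atom_m=" << meta[0] << " atom_n=" << meta[1]
                  << " atom_k=" << meta[2];
      }
    }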
src/op/gemm_py.h View file @ bbbf4207
@@ -18,87 +18,54 @@ using namespace tir;
 class GemmPyNode : public TileOperatorNode {
 public:
-  bool CheckWGMMA() const;
-  tir::Buffer A, B, C;
-  // pointer to the A, B, C
-  PrimExpr Aptr, Bptr, Cptr;
-  bool trans_A, trans_B;
-  int M, N, K;
-  int stride_A, stride_B;
-  int offset_A, offset_B;
-  PrimExpr clear_accum = const_false();
+  bool checkWgmma() const;
+  bool allowTcgen5Mma(Target target) const;
+  bool allowWgmma(int block_size, Target target) const;
+  tir::Buffer a_, b_, c_;
+  // BufferRegion for A, B and C
+  BufferRegion aRegion_, bRegion_, cRegion_;
+  bool transA_, transB_;
+  int m_, n_, k_;
+  int strideA_, strideB_;
+  int offsetA_, offsetB_;
+  PrimExpr clearAccum_ = const_false();
+  PrimExpr mbarPtr_;
+  std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
+  Array<PrimExpr> cCoords_;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
-  int kPack = 1;
-  int wg_wait = 0;
-  mutable GemmWarpPolicy policy;
-
-  static constexpr const char *_type_key = "tl.GemmPy";
-  TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode);
+  int kPack_ = 1;
+  int wgWait_ = 0;
+  mutable GemmWarpPolicy policy_;
+
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<GemmPyNode>()
-        .def_ro("A", &GemmPyNode::A)
-        .def_ro("B", &GemmPyNode::B)
-        .def_ro("C", &GemmPyNode::C)
-        .def_ro("Aptr", &GemmPyNode::Aptr)
-        .def_ro("Bptr", &GemmPyNode::Bptr)
-        .def_ro("Cptr", &GemmPyNode::Cptr)
-        .def_ro("trans_A", &GemmPyNode::trans_A)
-        .def_ro("trans_B", &GemmPyNode::trans_B)
-        .def_ro("M", &GemmPyNode::M)
-        .def_ro("N", &GemmPyNode::N)
-        .def_ro("K", &GemmPyNode::K)
-        .def_ro("stride_A", &GemmPyNode::stride_A)
-        .def_ro("stride_B", &GemmPyNode::stride_B)
-        .def_ro("offset_A", &GemmPyNode::offset_A)
-        .def_ro("offset_B", &GemmPyNode::offset_B)
-        .def_ro("clear_accum", &GemmPyNode::clear_accum)
-        .def_ro("kPack", &GemmPyNode::kPack)
-        .def_ro("wg_wait", &GemmPyNode::wg_wait)
-        .def_ro("policy", &GemmPyNode::policy);
-  }
-
-  bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const {
-    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-           equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
-           equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
-           equal(trans_B, other->trans_B) && equal(M, other->M) &&
-           equal(N, other->N) && equal(K, other->K) &&
-           equal(stride_A, other->stride_A) &&
-           equal(stride_B, other->stride_B) &&
-           equal(offset_A, other->offset_B) &&
-           equal(offset_B, other->offset_B) &&
-           equal(clear_accum, other->clear_accum) &&
-           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
-           equal(policy, other->policy);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(A);
-    hash_reduce(B);
-    hash_reduce(C);
-    hash_reduce(Aptr);
-    hash_reduce(Bptr);
-    hash_reduce(Cptr);
-    hash_reduce(trans_A);
-    hash_reduce(trans_B);
-    hash_reduce(M);
-    hash_reduce(N);
-    hash_reduce(K);
-    hash_reduce(stride_A);
-    hash_reduce(stride_B);
-    hash_reduce(offset_A);
-    hash_reduce(offset_B);
-    hash_reduce(clear_accum);
-    hash_reduce(kPack);
-    hash_reduce(wg_wait);
-    hash_reduce(policy);
-  }
-
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
+        .def_ro("a", &GemmPyNode::a_)
+        .def_ro("b", &GemmPyNode::b_)
+        .def_ro("c", &GemmPyNode::c_)
+        .def_ro("aRegion", &GemmPyNode::aRegion_)
+        .def_ro("bRegion", &GemmPyNode::bRegion_)
+        .def_ro("cRegion", &GemmPyNode::cRegion_)
+        .def_ro("transA", &GemmPyNode::transA_)
+        .def_ro("transB", &GemmPyNode::transB_)
+        .def_ro("m", &GemmPyNode::m_)
+        .def_ro("n", &GemmPyNode::n_)
+        .def_ro("k", &GemmPyNode::k_)
+        .def_ro("strideA", &GemmPyNode::strideA_)
+        .def_ro("strideB", &GemmPyNode::strideB_)
+        .def_ro("offsetA", &GemmPyNode::offsetA_)
+        .def_ro("offsetB", &GemmPyNode::offsetB_)
+        .def_ro("clearAccum", &GemmPyNode::clearAccum_)
+        .def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
+        .def_ro("cCoords", &GemmPyNode::cCoords_)
+        .def_ro("kPack", &GemmPyNode::kPack_)
+        .def_ro("wgWait", &GemmPyNode::wgWait_)
+        .def_ro("policy", &GemmPyNode::policy_);
+  }

   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
                         InferLevel level) const override;
@@ -106,7 +73,7 @@ public:
   TileOperator Clone() const;

   // Target GEMM instruction
-  GemmInst GetGemmInst(int block_size, Target target) const;
+  GemmInst getGemmInst(int block_size, Target target) const;

 private:
   mutable bool completed_ = false;
@@ -114,7 +81,7 @@ private:
 class GemmPy : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
   TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
src/op/gemm_sp.cc View file @ bbbf4207
@@ -18,14 +18,14 @@
 namespace tvm {
 namespace tl {

-std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
+std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
                                                                int block_size,
                                                                Target target,
                                                                bool use_wgmma,
                                                                int bits) const {
   int num_warps = block_size / TargetGetWarpSize(target);
-  auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
+  auto [m_warp, n_warp] = GemmWarpPolicyNode::computeWarpPartition(
       M, N, block_size, target,
       use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA);

   // Special handling for gemm_sp when the tiling size is not a multiple
@@ -84,26 +84,26 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
  * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
  */
 GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
-  node->A = vmap[GetVarFromAccessPtr(args[0])];
-  node->E = vmap[GetVarFromAccessPtr(args[1])];
-  node->B = vmap[GetVarFromAccessPtr(args[2])];
-  node->C = vmap[GetVarFromAccessPtr(args[3])];
-  node->trans_A = args[4].as<Bool>().value();
-  node->trans_B = args[5].as<Bool>().value();
-  node->M = args[6].as<IntImm>().value()->value;
-  node->N = args[7].as<IntImm>().value()->value;
-  node->K = args[8].as<IntImm>().value()->value;
-  node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
-  node->clear_accum = args[10].as<Bool>().value();
+  ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
+  node->a_ = vmap[GetVarFromAccessPtr(args[0])];
+  node->e_ = vmap[GetVarFromAccessPtr(args[1])];
+  node->b_ = vmap[GetVarFromAccessPtr(args[2])];
+  node->c_ = vmap[GetVarFromAccessPtr(args[3])];
+  node->transA_ = args[4].as<Bool>().value();
+  node->transB_ = args[5].as<Bool>().value();
+  node->m_ = args[6].as<IntImm>().value()->value;
+  node->n_ = args[7].as<IntImm>().value()->value;
+  node->k_ = args[8].as<IntImm>().value()->value;
+  node->policy_ = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
+  node->clearAccum_ = args[10].as<Bool>().value();
   if (args.size() > 11) {
-    node->kPack = args[11].as<IntImm>().value()->value;
-    if (node->kPack != 1 && node->kPack != 2) {
+    node->kPack_ = args[11].as<IntImm>().value()->value;
+    if (node->kPack_ != 1 && node->kPack_ != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 12) {
-    node->wg_wait = args[12].as<IntImm>().value()->value;
+    node->wgWait_ = args[12].as<IntImm>().value()->value;
   }
   data_ = std::move(node);
 }
@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A TileOperator holding a cloned GemmSPNode.
  */
 TileOperator GemmSPNode::Clone() const {
-  auto op = make_object<GemmSPNode>(*this);
+  auto op = tvm::ffi::make_object<GemmSPNode>(*this);
   return GemmSP(op);
 }
@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   int warp_size = 32;
   auto block_size = *as_const_int(T.thread_bounds->extent);
-  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
+  bool maybe_wgmma = TargetIsHopper(T.target) && (this->m_ >= 64) &&
                      (block_size / warp_size % 4 == 0);

-  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-      M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
+  auto [warp_m, warp_n] = policy_->computeWarpPartition(
+      m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());

   std::stringstream ss;
   std::string op_name = "tl::gemm_sp_ss";
-  ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
-         (B.scope() == "shared" || B.scope() == "shared.dyn"))
-      << "Only support shared.dyn scope for A and B, but received "
-      << A.scope() << " and " << B.scope();
-  ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
+  ICHECK((a_.scope() == "shared" || a_.scope() == "shared.dyn") &&
+         (b_.scope() == "shared" || b_.scope() == "shared.dyn"))
+      << "Only support shared.dyn scope for A and B, but received "
+      << a_.scope() << " and " << b_.scope();
+  ICHECK((e_.scope() == "shared" || e_.scope() == "shared.dyn"))
       << "Only support shared.dyn scope for E as copy from smem to rmem are "
          "delegated to cute implementation, found "
-      << E.scope();
-  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
+      << e_.scope();
+  ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
   ss << warp_m << ", " << warp_n << ", ";
-  ss << trans_A << ", " << trans_B;
-  ss << ", " << clear_accum;
+  ss << transA_ << ", " << transB_;
+  ss << ", " << clearAccum_;
   if (TargetIsHopper(T.target)) {
     ss << ", " << (maybe_wgmma ? "true" : "false");
   }
-  if (wg_wait != 0) {
-    ss << ", " << wg_wait;
+  if (wgWait_ != 0) {
+    ss << ", " << wgWait_;
   }
   ss << ">";

-  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
-  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
-  auto C_buffer = T.buffer_remap[C];
-  auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;
+  auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_;
+  auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_;
+  auto C_buffer = T.buffer_remap[c_];
+  auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_;

   auto new_call = Call(DataType::Handle(), tl::tl_gemm_sp(),
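The stream above assembles the name of a C++ template instantiation; booleans print as 0/1, the Hopper flag is appended only on that target, and wgWait_ only when nonzero. For illustrative values m_=n_=k_=64, a 2x1 warp split, no transposes, no accumulator clear, on Hopper without WGMMA, it would emit:

    // tl::gemm_sp_ss<64, 64, 64, 2, 1, 0, 0, 0, false>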
@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
   if (completed_)
     return {};
   LayoutMap results;
-  ICHECK(C.scope() == "local.fragment");
+  ICHECK(c_.scope() == "local.fragment");
   auto thread_range = T.thread_bounds;
   auto block_size = *as_const_int(thread_range->extent);
   if (TargetIsHopper(T.target)) {
     const int warp_size = 32;
     constexpr int wgmma_m = 16 * 4;
     bool maybe_wgmma =
-        (this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
-    auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-        M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
-    auto fragment =
-        maybe_wgmma
-            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
-                                      C->dtype.bits())
-            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
-                                            mat_continuous, A->dtype.bits(),
-                                            trans_A ? 1 : 2));
+        (this->m_ >= wgmma_m) && (block_size / warp_size % 4 == 0);
+    auto [warp_m, warp_n] = policy_->computeWarpPartition(
+        m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
+    auto fragment =
+        maybe_wgmma
+            ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m, n_ / warp_n,
+                                      c_->dtype.bits())
+            : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                                c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_, makeGemmABLayoutHopper(mat_stride, mat_continuous,
+                                             mat_continuous, a_->dtype.bits(),
+                                             transA_ ? 1 : 2));
     } else {
       ICHECK(false) << "Not implemented";
     }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
       const int64_t continuity =
-          trans_B ? mat_continuous : mat_continuous / warp_n;
-      results.Set(B, makeGemmABLayoutHopper(mat_stride, mat_continuous,
-                                            continuity, B->dtype.bits(),
-                                            trans_B ? 2 : 1));
+          transB_ ? mat_continuous : mat_continuous / warp_n;
+      results.Set(b_, makeGemmABLayoutHopper(mat_stride, mat_continuous,
+                                             continuity, b_->dtype.bits(),
+                                             transB_ ? 2 : 1));
     } else {
       ICHECK(false) << "WGMMA only support B in shared.";
     }
   } else if (TargetIsAmpere(T.target)) {
-    auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-        M, N, block_size, T.target, false, A->dtype.bits());
-    auto fragment = makeGemmSparseFragmentC(M, N, M / warp_m, N / warp_n,
-                                            C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  A->dtype.bits()));
-    } else if (A.scope() == "local.fragment") {
+    auto [warp_m, warp_n] = policy_->computeWarpPartition(
+        m_, n_, block_size, T.target, false, a_->dtype.bits());
+    auto fragment = makeGemmSparseFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                                            c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
+                                                   a_->dtype.bits()));
+    } else if (a_.scope() == "local.fragment") {
       // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
       // A->dtype.bits(), trans_A);
       // results.Set(A, fragment->BindThreadRange(thread_range));
@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
     } else {
       ICHECK(0);
     }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
-      results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  B->dtype.bits()));
-    } else if (B.scope() == "local.fragment") {
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
+      results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
+                                                   b_->dtype.bits()));
+    } else if (b_.scope() == "local.fragment") {
       // auto fragment =
       // makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
       // results.Set(B, fragment->BindThreadRange(thread_range));
@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }

 } // namespace tl
 } // namespace tvm
src/op/gemm_sp.h View file @ bbbf4207
@@ -18,30 +18,32 @@ using namespace tir;
 class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
 public:
-  std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
+  std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
                                            Target target, bool use_wgmma,
                                            int bits) const;
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
+                              GemmWarpPolicyNode);
 };

 class GemmSPWarpPolicy : public ObjectRef {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef,
-                                GemmSPWarpPolicyNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef,
+                                             GemmSPWarpPolicyNode);

   explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) {
-    auto node = make_object<GemmSPWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
     node->policy_type = (int)policy_type;
     data_ = std::move(node);
   }

   explicit GemmSPWarpPolicy(int policy_type) {
-    auto node = make_object<GemmSPWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
     node->policy_type = policy_type;
     data_ = std::move(node);
   }

   explicit GemmSPWarpPolicy(int m_warp, int n_warp) {
-    auto node = make_object<GemmSPWarpPolicyNode>();
+    auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
     node->m_warp = m_warp;
     node->n_warp = n_warp;
     node->policy_type = (int)GemmWarpPolicyType::kFree;
@@ -51,19 +53,18 @@ public:
 class GemmSPNode : public TileOperatorNode {
 public:
-  tir::Buffer A, B, C, E;
-  bool trans_A, trans_B;
-  int M, N, K;
-  bool clear_accum = false;
+  tir::Buffer a_, b_, c_, e_;
+  bool transA_, transB_;
+  int m_, n_, k_;
+  bool clearAccum_ = false;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
-  int kPack = 1;
-  int wg_wait = 0;
-  mutable GemmSPWarpPolicy policy;
-
-  static constexpr const char *_type_key = "tl.GemmSP";
-  TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
+  int kPack_ = 1;
+  int wgWait_ = 0;
+  mutable GemmSPWarpPolicy policy_;
+
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
                         InferLevel level) const override;
@@ -73,44 +74,19 @@ public:
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<GemmSPNode>()
-        .def_ro("policy", &GemmSPNode::policy)
-        .def_ro("A", &GemmSPNode::A)
-        .def_ro("B", &GemmSPNode::B)
-        .def_ro("C", &GemmSPNode::C)
-        .def_ro("E", &GemmSPNode::E)
-        .def_ro("trans_A", &GemmSPNode::trans_A)
-        .def_ro("trans_B", &GemmSPNode::trans_B)
-        .def_ro("M", &GemmSPNode::M)
-        .def_ro("N", &GemmSPNode::N)
-        .def_ro("K", &GemmSPNode::K)
-        .def_ro("clear_accum", &GemmSPNode::clear_accum)
-        .def_ro("kPack", &GemmSPNode::kPack)
-        .def_ro("wg_wait", &GemmSPNode::wg_wait);
-  }
-
-  bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
-    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-           equal(E, other->E) && equal(trans_A, other->trans_A) &&
-           equal(trans_B, other->trans_B) && equal(M, other->M) &&
-           equal(N, other->N) && equal(K, other->K) &&
-           equal(clear_accum, other->clear_accum) &&
-           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(policy);
-    hash_reduce(A);
-    hash_reduce(B);
-    hash_reduce(C);
-    hash_reduce(E);
-    hash_reduce(trans_A);
-    hash_reduce(trans_B);
-    hash_reduce(M);
-    hash_reduce(N);
-    hash_reduce(K);
-    hash_reduce(clear_accum);
-    hash_reduce(kPack);
-    hash_reduce(wg_wait);
+        .def_ro("policy", &GemmSPNode::policy_)
+        .def_ro("a", &GemmSPNode::a_)
+        .def_ro("b", &GemmSPNode::b_)
+        .def_ro("c", &GemmSPNode::c_)
+        .def_ro("e", &GemmSPNode::e_)
+        .def_ro("transA", &GemmSPNode::transA_)
+        .def_ro("transB", &GemmSPNode::transB_)
+        .def_ro("m", &GemmSPNode::m_)
+        .def_ro("n", &GemmSPNode::n_)
+        .def_ro("k", &GemmSPNode::k_)
+        .def_ro("clearAccum", &GemmSPNode::clearAccum_)
+        .def_ro("kPack", &GemmSPNode::kPack_)
+        .def_ro("wgWait", &GemmSPNode::wgWait_);
   }

 private:
@@ -119,7 +95,7 @@ private:
 class GemmSP : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
   TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
src/op/logical.cc View file @ bbbf4207
@@ -9,6 +9,8 @@
 #include <tvm/tir/op.h>
 #include <tvm/tir/op_attr_types.h>

+#include "../support/ffi_aliases.h"
+
 namespace tvm {
 namespace tl {

 using namespace tir;
src/op/math.cc View file @ bbbf4207
@@ -9,6 +9,8 @@
 #include <tvm/tir/op.h>
 #include <tvm/tir/op_attr_types.h>

+#include "../support/ffi_aliases.h"
+
 namespace tvm {
 namespace tl {

 using namespace tir;
@@ -33,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int")
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);

+PrimExpr infinity_op(PrimExpr args) {
+  const CallNode *call = args.as<CallNode>();
+  CHECK(call != nullptr);
+  const DataType &dtype = call->dtype;
+  ICHECK_EQ(dtype.lanes(), 1);
+  // NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype
+  if (dtype.is_float()) {
+    if (dtype.bits() == 64 || dtype.bits() == 32 || dtype.bits() == 16) {
+      return FloatImm(dtype, std::numeric_limits<float>::infinity(),
+                      call->span);
+    }
+  } else if (dtype.is_bfloat16()) {
+    return FloatImm(dtype, std::numeric_limits<float>::infinity(),
+                    call->span);
+  }
+  LOG(FATAL) << "Cannot decide infinity for type " << dtype;
+  throw; // Unreachable, keeps compiler happy
+}
+
+TVM_REGISTER_OP("tl.infinity")
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure))
+    .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
+    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
+
 } // namespace tl
 } // namespace tvm
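One subtlety worth calling out: bfloat16 gets its own branch because, as the code implies, DataType::is_float() covers only the IEEE float type codes, so bf16 would otherwise fall through to the fatal path. The observable behavior by dtype (derived from the branches above):

    // tl.infinity on a scalar fp16/fp32/fp64 call -> FloatImm(dtype, +inf)
    // tl.infinity on a scalar bfloat16 call       -> FloatImm(dtype, +inf)
    // vector dtypes fail the ICHECK_EQ(lanes, 1); anything else (ints,
    // narrower floats) reaches LOG(FATAL).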
src/op/operator.cc View file @ bbbf4207
@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
 TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
   if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
     auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
-    return ParseOperator(GetRef<Call>(call), vmap);
+    return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
   }
   return TileOperator();
 }
@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
   ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
   auto var = call->args[1].as<VarNode>();
   ICHECK(var);
-  return GetRef<Var>(var);
+  return tvm::ffi::GetRef<Var>(var);
 }

 } // namespace tl
src/op/operator.h View file @ bbbf4207
@@ -39,7 +39,6 @@ struct LowerArgs {
   AddWorkspaceCallback AddWorkspace;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
-  Array<Var> buffer_var_gemm;
 };

 struct LayoutInferArgs {
@@ -62,14 +61,13 @@ public:
   virtual TileOperator Clone() const = 0;

-  static constexpr const char *_type_key = "tl.TileOperator";
-  TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object);
 };

 class TileOperator : public ObjectRef {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef,
+                                             TileOperatorNode);
 };

 Var GetVarFromAccessPtr(const PrimExpr &expr);
src/op/parallel.cc View file @ bbbf4207
@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
 }

 TileOperator ParallelOpNode::Clone() const {
-  auto op = make_object<ParallelOpNode>(*this);
+  auto op = tvm::ffi::make_object<ParallelOpNode>(*this);
   return ParallelOp(op);
 }
@@ -620,11 +620,37 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   if (IsCommonAccessIndice(buffer)) {
     return loop_layout_;
   }
+  // Prefer a simple path: if original 2D indices form a bijective map, invert
+  // them directly and avoid introducing a synthetic replicate dimension.
+  {
+    auto res2d = arith::DetectIterMap(
+        indice_map_[buffer], ToVMap(loop_vars_), 1,
+        arith::IterMapLevel::Bijective,
+        const_cast<arith::Analyzer *>(&analyzer_));
+    if (res2d->errors.empty()) {
+      Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
+      PrimExpr indice_rep_extent = 1;
+      PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+      PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+      Array<PrimExpr> fwd2;
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        fwd2.push_back(InputPlaceholder(i));
+      }
+      PrimExpr thd_b2 = loop_layout_->ForwardThread(
+          ind_inv2d->Forward(fwd2), std::nullopt);
+      return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
+                      std::nullopt)
+          ->CondenseReplicateVar();
+    }
+  }
+  // Otherwise, infer an extra flattened iterator that captures truly-unused
+  // pieces of the loop space (if any), then try inversion with it.
   PrimExpr rep_b = MakeFlattenedExpression(
       DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
   auto bijective_indice = indice_map_[buffer];
   bijective_indice.push_back(rep_b);
   Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
   PrimExpr indice_rep_extent =
       ind_inv->InputShape().back(); // this is the size of rep_b
   PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
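A concrete instance of the fast path: a parallel loop over (i, j) that writes buf[j, i] has the index map (i, j) -> (j, i), which is bijective, so the fragment comes from inverting that map directly, with no synthetic replicate iterator. A standalone sketch of the mapping and its inverse with plain integers (the 4x8 loop space is an assumed example):

    #include <cassert>

    int main() {
      // Assumed 4x8 loop space writing buf[j][i]; the map (i, j) -> (j, i) is
      // bijective, so each buffer element identifies a unique loop iteration.
      const int I = 4, J = 8;
      for (int i = 0; i < I; ++i)
        for (int j = 0; j < J; ++j) {
          // Forward: buffer coordinates from loop coordinates.
          int b0 = j, b1 = i;
          // Inverse: loop coordinates recovered from buffer coordinates.
          int inv_i = b1, inv_j = b0;
          assert(inv_i == i && inv_j == j);
        }
      return 0;
    }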
@@ -642,7 +668,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
       ->CondenseReplicateVar();
 }

-TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }

 } // namespace tl
 } // namespace tvm
src/op/parallel.h View file @ bbbf4207
@@ -66,8 +66,8 @@ public:
   mutable Optional<PrimExpr> predicate_;

   // Type key for TVM object system.
-  static constexpr const char *_type_key = "tl.ParallelOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode,
+                                    TileOperatorNode);

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -77,20 +77,6 @@ public:
         .def_ro("predicate", &ParallelOpNode::predicate_);
   }

-  bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const {
-    return equal(root_, other->root_) &&
-           equal(loop_layout_, other->loop_layout_) &&
-           equal(predicate_, other->predicate_);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(root_);
-    hash_reduce(loop_layout_);
-    hash_reduce(predicate_);
-  }
-
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   // Construct from a root For loop.
   ParallelOpNode(For root);
...
...
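Note: the hand-written SEqualReduce/SHashReduce methods (and the `_type_has_method_*` flags) can go away because the same fields are registered with the reflection system, from which the tvm::ffi machinery can derive structural equality and hashing. A sketch of the registration that remains; only the `predicate` line is visible in this hunk, so the `root`/`loop_layout` entries and the `ObjectDef` helper name are assumptions:

    static void RegisterReflection() {
      namespace refl = tvm::ffi::reflection;
      refl::ObjectDef<ParallelOpNode>()
          .def_ro("root", &ParallelOpNode::root_)                // assumed
          .def_ro("loop_layout", &ParallelOpNode::loop_layout_)  // assumed
          .def_ro("predicate", &ParallelOpNode::predicate_);     // from this hunk
    }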
@@ -150,10 +136,11 @@ private:

 class ParallelOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator,
+                                             ParallelOpNode);

   ParallelOp(const For &root) {
-    auto op = make_object<ParallelOpNode>(root);
+    auto op = tvm::ffi::make_object<ParallelOpNode>(root);
     data_ = std::move(op);
   }
 };
...
...
src/op/reduce.cc
...
...
@@ -14,6 +14,7 @@
 #include "../op/parallel.h"
 #include "../target/utils.h"
 #include "../transform/loop_partition.h"
+#include "region.h"
 #include "tir/transforms/ir_utils.h"

 namespace tvm {
...
...
@@ -21,10 +22,54 @@ namespace tl {

 using namespace tir;

+// Normalize an argument (BufferRegion/BufferLoad/tl.region)
+// to BufferRegion so Reduce can uniformly consume regions.
+static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
+                                            const BufferMap &vmap) {
+  // Case 1: Already a BufferRegion
+  if (arg->IsInstance<BufferRegionNode>()) {
+    return Downcast<BufferRegion>(arg);
+  }
+  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
+  // extent=1)
+  if (const auto *load = arg.as<BufferLoadNode>()) {
+    Array<Range> ranges;
+    for (const PrimExpr &index : load->indices) {
+      if (const auto *ramp = index.as<RampNode>()) {
+        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
+        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
+            << "Only stride-1 Ramp is supported in region conversion";
+        ICHECK(ramp->lanes.as<IntImmNode>())
+            << "Scalable vector lanes not supported in region conversion";
+        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
+      } else {
+        ranges.push_back(Range::FromMinExtent(index, 1));
+      }
+    }
+    return BufferRegion(load->buffer, ranges);
+  }
+  // Case 3: Call nodes (only tl.region)
+  if (const auto *call = arg.as<CallNode>()) {
+    // tl.region(...) — reconstruct via RegionOp
+    if (call->op.same_as(RegionOp::Get())) {
+      RegionOp region(call->args, vmap);
+      return BufferRegion(region->GetBuffer(), region->GetRanges());
+    }
+  }
+  LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
+  throw; // Unreachable
+}
+
 ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
-  node->src = vmap[GetVarFromAccessPtr(args[0])];
-  node->dst = vmap[GetVarFromAccessPtr(args[1])];
+  ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
+  // Accept BufferRegion/BufferLoad/tl.region for src/dst
+  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
+  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
+  node->src = node->srcRegion_->buffer;
+  node->dst = node->dstRegion_->buffer;
   std::string reduce_type = args[2].as<StringImm>().value()->value;
   node->dim = args[3].as<IntImm>().value()->value;
   node->type = ReduceType(reduce_type);
...
...
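Note: `NormalizeToBufferRegion` is what lets reduce accept a region-style source/destination rather than only an access pointer. A worked example of case 2 under hypothetical indices:

    // A vectorized load A[Ramp(/*base=*/16, /*stride=*/1, /*lanes=*/8), 3]
    // normalizes to
    //   BufferRegion(A, { Range::FromMinExtent(16, 8),    // Ramp -> 8-wide range
    //                     Range::FromMinExtent(3, 1) });  // scalar index -> extent 1
    // A Ramp with stride != 1 or scalable lanes trips the ICHECKs above.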
@@ -33,12 +78,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
 }

 TileOperator ReduceOpNode::Clone() const {
-  auto op = make_object<ReduceOpNode>(*this);
+  auto op = tvm::ffi::make_object<ReduceOpNode>(*this);
   return ReduceOp(op);
 }

 TileOperator CumSumOpNode::Clone() const {
-  auto op = make_object<CumSumOpNode>(*this);
+  auto op = tvm::ffi::make_object<CumSumOpNode>(*this);
   return CumSumOp(op);
 }
...
...
@@ -85,6 +130,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
     return make_zero(dst->dtype);
   } else {
     LOG(FATAL) << "Unsupported reduce type: " << type->type;
+    return PrimExpr();
   }
 }
...
...
@@ -103,7 +149,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
   } else if (type->isMin()) {
     return Min(lhs, rhs);
   } else if (type->isAbsMax()) {
-    return Max(Max(lhs, rhs), -Min(lhs, rhs));
+    return Max(tvm::abs(lhs), tvm::abs(rhs));
   } else if (type->isBitAnd()) {
     return lhs & rhs;
   } else if (type->isBitOr()) {
...
...
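Note: for absmax the two expressions agree, since max(max(l, r), -min(l, r)) = max(|l|, |r|) for any real l, r. A quick check with l = -5, r = 3:

    // old: Max(Max(-5, 3), -Min(-5, 3)) = Max(3, 5) = 5
    // new: Max(abs(-5), abs(3))         = Max(5, 3) = 5
    // The abs form states the intent directly instead of encoding it
    // through a negated Min.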
@@ -359,70 +405,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return body;
   }

-  auto is_shared_scope = [](const std::string &scope) {
-    return scope == "shared" || scope == "shared.dyn";
-  };
-  if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
-    Buffer src_buffer = get_buffer(this->src);
-    Buffer dst_buffer = get_buffer(this->dst);
-    size_t src_dim = src_buffer->shape.size();
-    size_t dst_dim = dst_buffer->shape.size();
-    bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
-    if (!is_1d_reduce) {
-      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
-    } else {
-      ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
-    }
-    auto thread_extent = as_const_int(T.thread_bounds->extent);
-    ICHECK(thread_extent) << "Shared-memory reduce requires static thread extent.";
-    int threads = *thread_extent;
-    if (TargetIsCuda(T.target)) {
-      ICHECK_EQ(threads % 32, 0)
-          << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
-    } else if (TargetIsRocm(T.target)) {
-      ICHECK_EQ(threads % 64, 0)
-          << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
-    }
-    bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
-    bool need_accumulate =
-        (!this->clear) &&
-        (this->type->isSum() || this->type->isAbsSum() ||
-         this->type->isBitAnd() || this->type->isBitOr() ||
-         this->type->isBitXor());
-    PrimExpr reduce_extent = src_buffer->shape[this->dim];
-    PrimExpr tail_extent = make_const(DataType::Int(32), 1);
-    for (size_t i = this->dim + 1; i < src_dim; ++i) {
-      tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
-    }
-    PrimExpr total_dest = make_const(DataType::Int(32), 1);
-    for (size_t i = 0; i < dst_dim; ++i) {
-      total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
-    }
-    std::stringstream ss;
-    std::string reducer = this->MakeCodegenReducer();
-    ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
-       << (use_abs ? "true" : "false") << ", "
-       << (need_accumulate ? "true" : "false") << ">::run";
-    Array<PrimExpr> call_args = {StringImm(ss.str()),
-                                 src_buffer.access_ptr(1),
-                                 dst_buffer.access_ptr(3),
-                                 cast(DataType::Int(32), total_dest),
-                                 cast(DataType::Int(32), reduce_extent),
-                                 cast(DataType::Int(32), tail_extent),
-                                 this->MakeInitValue()};
-    return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
-  }
-
   LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
              << dst_scope << ") is not implemented.";
   return Stmt();
...
...
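Note: the deleted branch lowered shared-to-shared reductions through a single templated extern kernel. For reference, the shape of the call it used to emit, reconstructed from the removed string-building code (the reducer string comes from MakeCodegenReducer(), which is not shown here; the numeric values are hypothetical):

    // tl::SharedReduceWarp</*Reducer=*/..., /*threads=*/128,
    //                      /*use_abs=*/false, /*need_accumulate=*/true>::run(
    //     src_ptr,      // src_buffer.access_ptr(1): read access
    //     dst_ptr,      // dst_buffer.access_ptr(3): read-write access
    //     total_dest, reduce_extent, tail_extent, init_value);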
@@ -432,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
   if (level >= InferLevel::kStrict)
     return {};
   if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
       T.layout_map.count(src)) {
     auto src_layout = T.layout_map[src].as<Fragment>().value();
...
...
@@ -452,10 +435,40 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
     }
     auto thd = src_layout->ForwardThread(
         fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
+    // Ensure the thread count is divisible by the replicate extent.
+    // Otherwise, we cannot infer a valid fragment<->fragment layout.
+    {
+      arith::Analyzer analyzer;
+      PrimExpr num_threads = T.thread_bounds->extent;
+      // Though the dest_buffer_rep_extent will be compressed at
+      // CondenseReplicateVar, we need to check the divisibility here to avoid
+      // the issue that the thread count is not divisible by the replicate
+      // extent.
+      if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) == 0) &&
+          !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) == 0)) {
+        ICHECK(false) << "ReduceOp fragment layout inference failed: "
+                         "num_threads % replicate_extent != 0. "
+                      << "This mapping requires the block's thread count to be "
+                         "divisible by the "
+                      << "replicate extent. "
+                      << "Try one of: (1) choose a thread block size divisible "
+                         "by replicate_extent; "
+                      << "(2) pick a different reduce dimension or adjust the "
+                         "source fragment layout; "
+                      << "Details: num_threads=" << num_threads
+                      << ", replicate_extent=" << indice_rep_extent
+                      << ", src=" << src << ", dst=" << dst;
+      }
+    }
     Fragment dst_layout =
         Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
             ->CondenseReplicateVar()
             ->BindThreadRange(T.thread_bounds);
     if (!T.layout_map.count(dst))
       return {{dst, dst_layout}};
     else {
...
...
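Note: the guard rejects mappings where neither extent divides the other. A worked example with assumed sizes: for `T.thread_bounds->extent = 128`, a replicate extent of 4 passes, while 3 fails both proofs and trips the ICHECK:

    // num_threads = 128, replicate_extent = 4: floormod(128, 4) == 0 -> OK
    // num_threads = 128, replicate_extent = 3: 128 % 3 == 2 and 3 % 128 == 3,
    //   so neither CanProve succeeds and the diagnostic above fires.
    // A block size divisible by the extent (e.g. 96 for 3) makes the layout
    // inferable again.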
@@ -512,7 +525,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
   /// - dim: dimension to cumsum
   /// - reverse: whether to cumsum in reverse order
   CHECK_EQ(args.size(), 4);
-  ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
+  ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
   node->src = vmap[GetVarFromAccessPtr(args[0])];
   node->dst = vmap[GetVarFromAccessPtr(args[1])];
   node->dim = args[2].as<IntImm>().value()->value;
...
...
@@ -567,5 +580,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
     .set_num_inputs(4)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

+TVM_FFI_STATIC_INIT_BLOCK() {
+  ReduceOpNode::RegisterReflection();
+  CumSumOpNode::RegisterReflection();
+  ReduceTypeNode::RegisterReflection();
+}
+
 } // namespace tl
 } // namespace tvm