OpenDAS / tilelang · Commits

Commit bbbf4207 (unverified)
Authored Nov 14, 2025 by guchaoyang; committed by GitHub on Nov 14, 2025.

    Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f
Changes: 286 — showing 20 changed files with 1677 additions and 723 deletions (+1677 / -723).
Changed files shown:

  src/target/rt_mod_hip.cc                                    +8    -7
  src/target/utils.cc                                         +9    -6
  src/tl_templates/cuda/common.h                              +273  -2
  src/tl_templates/cuda/cuda_fp8.h                            +3    -2
  src/tl_templates/cuda/gemm_mma.h                            +8    -6
  src/tl_templates/cuda/gemm_sm100.h                          +6    -4
  src/tl_templates/cuda/gemm_sm90.h                           +6    -4
  src/tl_templates/cuda/gemm_sp_sm90.h                        +6    -4
  src/tl_templates/cuda/instruction/mma.h                     +163  -0
  src/tl_templates/cuda/instruction/mma_sm70.h                +353  -0
  src/tl_templates/cuda/instruction/tcgen05mma.h              +337  -0
  src/tl_templates/cuda/instruction/wgmma.h                   +424  -600
  src/tl_templates/cuda/intrin.h                              +14   -0
  src/tl_templates/cuda/reduce.h                              +23   -56
  src/tl_templates/cuda/tcgen_05.h                            +10   -6
  src/transform/align_dynamic_shared_memory_allocations.cc    +6    -6
  src/transform/annotate_device_regions.cc                    +4    -4
  src/transform/annotate_warp_group_reg_alloc.cc              +17   -10
  src/transform/arg_binder.cc                                 +3    -3
  src/transform/arg_binder.h                                  +4    -3
src/target/rt_mod_hip.cc

@@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) {
       }
       info.arg_types.push_back(f->params[i].dtype());
     }
-    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
+    if (auto opt =
+            f->GetAttr<ffi::Array<ffi::String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

-runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
+ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);

@@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

-runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
+ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);

@@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
                          std::string());
}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_hip", BuildTileLangHIP)
      .def("target.build.tilelang_hip_without_compile",
           BuildTileLangHIPWithoutCompile);
-});
+}
} // namespace codegen
} // namespace tvm
\ No newline at end of file
src/target/utils.cc

@@ -5,6 +5,9 @@
#include "utils.h"

+#include "../support/ffi_aliases.h"
+#include <tvm/node/node.h>
+
namespace tvm {
namespace tl {

@@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) {
}

int GetArchInt(Target target) {
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
+  auto s = target->GetAttr<tvm::ffi::String>("arch");
+  ICHECK(s.has_value());
  const std::string arch_str = s.value();
  ICHECK(arch_str.size() >= 3);
  ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)

@@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
-    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
+    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // if mcpu start with "gfx9", it is CDNA
    return mcpu.find("gfx9") == 0;
  }

@@ -94,7 +97,7 @@ bool TargetHasAsyncCopy(Target target) {
    return arch >= 80;
  } else if (TargetIsCDNA(target)) {
    if (target->attrs.count("mcpu")) {
-      std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
+      std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
      if (mcpu.rfind("gfx9", 0) == 0) {
        int gfx_version = std::stoi(mcpu.substr(3, 2));
        return gfx_version >= 94;

@@ -141,7 +144,7 @@ int TargetGetWarpSize(Target target) {
  return res;
}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.TargetIsCuda",

@@ -170,7 +173,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
          [](Target target) { return TargetHasBulkCopy(target); })
      .def("tl.TargetGetWarpSize",
           [](Target target) { return TargetGetWarpSize(target); });
-});
+}
} // namespace tl
} // namespace tvm
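The GetArchInt change above only swaps the attribute accessor (GetAttr<tvm::ffi::String> plus has_value() instead of defined()); the arch-string parsing itself is untouched. As a minimal stand-alone sketch of what that parsing amounts to (hypothetical helper, not part of the commit):

    #include <cassert>
    #include <string>

    // "sm_90" -> 90, mirroring the ICHECKs shown in GetArchInt above.
    int ParseArchInt(const std::string &arch_str) {
      assert(arch_str.size() >= 3 && arch_str.compare(0, 3, "sm_") == 0);
      return std::stoi(arch_str.substr(3));
    }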
src/tl_templates/cuda/common.h

@@ -10,6 +10,9 @@
#include <cutlass/numeric_types.h>
#include <math_constants.h>

+#include <cutlass/bfloat16.h>
+#include <cutlass/float8.h>
+
using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;

@@ -285,6 +288,138 @@ union GmmaDescriptor {
  }
};

+union Tcgen05SMemDescriptor {
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept
+      : desc_(desc) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(Tcgen05SMemDescriptor const &t) noexcept
+      : desc_(t.desc_) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(Tcgen05SMemDescriptor &&t) noexcept
+      : desc_(t.desc_) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
+  operator=(Tcgen05SMemDescriptor const &t) noexcept {
+    desc_ = t.desc_;
+    return *this;
+  }
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
+  operator=(Tcgen05SMemDescriptor &&t) noexcept {
+    desc_ = t.desc_;
+    return *this;
+  }
+
+  uint64_t desc_;
+  uint32_t reg32_[2];
+
+  // Bitfield implementation avoids the need for shifts in assignment
+  struct {
+    // start_address, bit [0,14), 4LSB not included
+    uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
+    // leading dimension byte offset, bit [16,30), 4LSB not included
+    uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
+    // stride dimension byte offset, bit [32,46), 4LSB not included
+    uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16)
+    // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
+    uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
+    // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0,
+    // SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4,
+    // SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5,
+    // N/A = 7
+    uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8)
+  } bitfield;
+
+  // Separate the field, as we may only update one part of desc
+  struct {
+    uint32_t lo;
+    uint32_t hi;
+  } words;
+
+  CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { return desc_; }
+
+  template <typename T>
+  CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor operator+(const T &offset) const {
+    Tcgen05SMemDescriptor ret;
+    // Address addition is in units of 16 bytes (4 LSB not encoded)
+    ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4);
+    ret.reg32_[1] = reg32_[1];
+    return ret;
+  }
+};
+
+//
+// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout)
+//
+union Tcgen05InstrDescriptor {
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept
+      : desc_(desc) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(Tcgen05InstrDescriptor const &t) noexcept
+      : desc_(t.desc_) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(Tcgen05InstrDescriptor &&t) noexcept
+      : desc_(t.desc_) {}
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
+  operator=(Tcgen05InstrDescriptor const &t) noexcept {
+    desc_ = t.desc_;
+    return *this;
+  }
+  CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
+  operator=(Tcgen05InstrDescriptor &&t) noexcept {
+    desc_ = t.desc_;
+    return *this;
+  }
+
+  uint32_t desc_;
+  uint16_t reg16_[2];
+
+  // Bitfield implementation mirrors cute::UMMA::InstrDescriptor
+  struct {
+    // bit [ 0, 2) : Sparse meta data id2
+    uint16_t sparse_id2_ : 2,
+        // bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for F32F16/S8/MXF8F6F4
+        sparse_flag_ : 1,
+        // bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8
+        saturate_ : 1,
+        // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
+        c_format_ : 2,
+        // padding
+        : 1,
+        // bit [ 7,10) : see UMMA format encoding
+        a_format_ : 3,
+        // bit [10,13) : see UMMA format encoding
+        b_format_ : 3,
+        // bit [13,14) : 0 = no negate. 1 = negate
+        a_negate_ : 1,
+        // bit [14,15) : 0 = no negate. 1 = negate
+        b_negate_ : 1,
+        // bit [15,16) : 0 = K-major. 1 = MN-major
+        a_major_ : 1;
+    // Upper 16 bits
+    uint16_t b_major_ : 1, // bit [16,17)
+        n_dim_ : 6,        // bit [17,23) : 3 LSBs not included
+        : 1,               // padding
+        m_dim_ : 5,        // bit [24,29) : 4 LSBs not included
+        : 1,               // padding
+        max_shift_ : 2;    // bit [30,32)
+  } bitfield;
+
+  // Decay to a uint32_t
+  CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept {
+    return desc_;
+  }
+};
+
// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
  for (int i = 0; i < size; i++) {

@@ -323,8 +458,8 @@ TL_DEVICE void __sync_thread_partial() {
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
-TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
-                                     T *start_address) {
+TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor,
+                                           T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;

@@ -333,15 +468,151 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

+template <typename T>
+TL_DEVICE void initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor,
+                                             T *start_address,
+                                             int leading_byte_offset,
+                                             int stride_byte_offset,
+                                             int base_offset,
+                                             bool leading_is_absolute,
+                                             int swizzle_mode) {
+  descriptor.bitfield.start_address_ =
+      static_cast<uint16_t>(cast_smem_ptr_to_uint(start_address) >> 4);
+  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
+  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  descriptor.bitfield.version_ = 1;
+  descriptor.bitfield.base_offset_ = base_offset & 0x7;
+  descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0;
+  descriptor.bitfield.layout_type_ = swizzle_mode & 0x7;
+}
+
template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                          T offset) {
  descriptor.reg32_[0] += (offset >> 4);
}

+// and add the desired implicit conversion from bfloat16_t.
+struct float_e4m3_t : public cute::float_e4m3_t {
+  using cute::float_e4m3_t::float_e4m3_t;
+  CUTLASS_HOST_DEVICE
+  float_e4m3_t() = default;
+  CUTLASS_HOST_DEVICE
+  explicit float_e4m3_t(__nv_bfloat16 x) : float_e4m3_t(static_cast<float>(x)) {}
+};
+
+struct float_e5m2_t : public cute::float_e5m2_t {
+  using cute::float_e5m2_t::float_e5m2_t;
+  CUTLASS_HOST_DEVICE
+  float_e5m2_t() = default;
+  CUTLASS_HOST_DEVICE
+  explicit float_e5m2_t(__nv_bfloat16 x) : float_e5m2_t(static_cast<float>(x)) {}
+};
+
+template <typename T> struct to_cute_type {
+  using type = T;
+};
+template <> struct to_cute_type<tl::float_e4m3_t> {
+  using type = cute::float_e4m3_t;
+};
+template <> struct to_cute_type<tl::float_e5m2_t> {
+  using type = cute::float_e5m2_t;
+};
+
} // namespace tl

namespace cutlass {
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass

+//
+// Type-safe warp shuffle helpers for 16-bit float types
+// These wrappers avoid relying on implicit conversions that may be disallowed
+// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
+// float for the shuffle and then down-converting.
+//
+namespace tl {
+
+// Generic passthroughs
+template <typename T>
+TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
+  return __shfl_xor_sync(mask, val, laneMask);
+}
+template <typename T>
+TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
+  return __shfl_down_sync(mask, val, delta);
+}
+template <typename T>
+TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
+  return __shfl_up_sync(mask, val, delta);
+}
+template <typename T>
+TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
+  return __shfl_sync(mask, val, srcLane);
+}
+
+// Specializations for cutlass::half_t
+template <>
+TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return half_t(r);
+}
+template <>
+TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return half_t(r);
+}
+template <>
+TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return half_t(r);
+}
+template <>
+TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return half_t(r);
+}
+
+// Specializations for cutlass::bfloat16_t
+template <>
+TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val, int laneMask) {
+  float f = static_cast<float>(val);
+  float r = __shfl_xor_sync(mask, f, laneMask);
+  return bfloat16_t(r);
+}
+template <>
+TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_down_sync(mask, f, delta);
+  return bfloat16_t(r);
+}
+template <>
+TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
+  float f = static_cast<float>(val);
+  float r = __shfl_up_sync(mask, f, delta);
+  return bfloat16_t(r);
+}
+template <>
+TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
+  float f = static_cast<float>(val);
+  float r = __shfl_sync(mask, f, srcLane);
+  return bfloat16_t(r);
+}
+
+} // namespace tl
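A quick illustration of why the typed shuffle wrappers exist: __shfl_xor_sync has no overload for cutlass::bfloat16_t, so the specializations above round-trip through float. The kernel below is a hypothetical sketch (not part of the commit) of a butterfly warp reduction written against that API:

    // Assumes common.h is included and a full warp is active.
    __global__ void warp_sum_bf16(const cutlass::bfloat16_t *in, float *out) {
      int lane = threadIdx.x & 31;
      cutlass::bfloat16_t v = in[lane];
      // Each XOR step folds partial sums across the warp.
      for (int offset = 16; offset > 0; offset >>= 1)
        v = v + tl::shfl_xor_sync(0xffffffffu, v, offset);
      if (lane == 0)
        *out = static_cast<float>(v); // lane 0 holds the warp-wide sum
    }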
src/tl_templates/cuda/cuda_fp8.h

#pragma once

+#include "common.h"
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>

-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e5_t = cute::float_e5m2_t;
+using fp8_e4_t = tl::float_e4m3_t;
+using fp8_e5_t = tl::float_e5m2_t;

struct __CUDA_ALIGN__(2) fp8_e4_2_t {
  fp8_e4_t x;
...
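The only behavioral point of retargeting fp8_e4_t/fp8_e5_t at the tl:: wrappers is the extra explicit __nv_bfloat16 constructor added in common.h. A minimal sketch of assumed usage (not commit code):

    // Narrow a bf16 value to fp8 without spelling out the float hop;
    // tl::float_e4m3_t routes it through static_cast<float> internally.
    __device__ fp8_e4_t narrow_bf16_to_fp8(__nv_bfloat16 x) {
      return fp8_e4_t(x);
    }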
src/tl_templates/cuda/gemm_mma.h

@@ -263,16 +263,18 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          typename C_type_raw>
class GemmTensorOp {
public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
  using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
  using C_type = C_type_raw;
  using Instruction =
-      DispatchInstruction<A_type_raw, B_type_raw, C_type_raw,
-                          num_warp_m, num_warp_n, N>;
+      DispatchInstruction<A_type, B_type, C_type,
+                          num_warp_m, num_warp_n, N>;
  using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
                                       !trans_A, num_warp_m, lda>;
...
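The A_type_cute/B_type_cute indirection exists so that the tl:: fp8 wrapper types decay back to the cute types before they reach DispatchInstruction and the operand traits. A small compile-time sketch of that mapping (assuming common.h is included; not part of the commit):

    static_assert(std::is_same_v<tl::to_cute_type<tl::float_e4m3_t>::type,
                                 cute::float_e4m3_t>);
    static_assert(std::is_same_v<tl::to_cute_type<cutlass::half_t>::type,
                                 cutlass::half_t>); // non-fp8 types pass through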
src/tl_templates/cuda/gemm_sm100.h

@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
          typename C_type_raw>
class GemmTensorOp {
public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
  using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, B_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
  using C_type = C_type_raw;

  static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
...
src/tl_templates/cuda/gemm_sm90.h

@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
...
src/tl_templates/cuda/gemm_sp_sm90.h

@@ -13,10 +13,12 @@ class GemmTensorOp {
public:
  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");

-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
  using C_type = C_type_raw;

  static constexpr bool need_tfloat32_cast =
...
src/tl_templates/cuda/instruction/mma.h (new file, 0 → 100644)

#pragma once

#include "../common.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <type_traits>
#include <utility>

namespace tl {

#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif

namespace detail {

template <class Impl> struct MmaImplTraits {
  using DReg = std::remove_extent_t<typename Impl::DRegisters>;
  using AReg = std::remove_extent_t<typename Impl::ARegisters>;
  using BReg = std::remove_extent_t<typename Impl::BRegisters>;
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
  static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
  static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};

template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx, size_t... CIdx>
TL_DEVICE void call_fma_impl(typename MmaImplTraits<Impl>::DReg *d,
                             const typename MmaImplTraits<Impl>::AReg *a,
                             const typename MmaImplTraits<Impl>::BReg *b,
                             const typename MmaImplTraits<Impl>::CReg *c,
                             std::index_sequence<DIdx...>,
                             std::index_sequence<AIdx...>,
                             std::index_sequence<BIdx...>,
                             std::index_sequence<CIdx...>) {
  Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}

template <class Impl>
TL_DEVICE void call_fma(typename MmaImplTraits<Impl>::DReg *d,
                        const typename MmaImplTraits<Impl>::AReg *a,
                        const typename MmaImplTraits<Impl>::BReg *b,
                        const typename MmaImplTraits<Impl>::CReg *c) {
  call_fma_impl<Impl>(d, a, b, c,
                      std::make_index_sequence<MmaImplTraits<Impl>::kDRegs>{},
                      std::make_index_sequence<MmaImplTraits<Impl>::kARegs>{},
                      std::make_index_sequence<MmaImplTraits<Impl>::kBRegs>{},
                      std::make_index_sequence<MmaImplTraits<Impl>::kCRegs>{});
}

template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
          bool TransA, bool TransB, bool Saturate>
struct MmaDispatcher {
  using CRegType = void;
  using ARegType = void;
  using BRegType = void;
  static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
                             const CRegType *) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::mma_sync: unsupported configuration");
  }
};

#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue,      \
                                 NValue, KValue, TransAValue, TransBValue,     \
                                 SaturateValue, ImplType)                      \
  template <>                                                                  \
  struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum,               \
                       DataType::CTypeEnum, MValue, NValue, KValue,            \
                       TransAValue, TransBValue, SaturateValue> {              \
    using Impl = ImplType;                                                     \
    using Traits = MmaImplTraits<Impl>;                                        \
    using CRegType = typename Traits::DReg;                                    \
    using ARegType = typename Traits::AReg;                                    \
    using BRegType = typename Traits::BReg;                                    \
    static_assert(                                                             \
        std::is_same_v<typename Traits::DReg, typename Traits::CReg>,          \
        "tl::mma_sync requires matching accumulator/output regs");             \
    static TL_DEVICE void exec(CRegType *d, const ARegType *a,                 \
                               const BRegType *b, const CRegType *c) {         \
      call_fma<Impl>(d, a, b, c);                                              \
    }                                                                          \
  };

// FP16 inputs (TN layout: A row-major, B column-major)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
                         false, cute::SM80_16x8x16_F16F16F16F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true,
                         false, cute::SM80_16x8x16_F32F16F16F32_TN)
// BF16 inputs
TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true,
                         false, cute::SM80_16x8x16_F32BF16BF16F32_TN)
// INT8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false,
                         cute::SM80_16x8x32_S32S8S8S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false,
                         cute::SM80_16x8x32_S32U8U8S32_TN)
// INT4 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false,
                         cute::SM80_16x8x32_S32S4S4S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false,
                         cute::SM80_16x8x32_S32U4U4S32_TN)
// FP8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
                         true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)
// TF32 inputs (FP32 math on Tensor Cores)
// Support both k=4 and k=8 variants on SM80
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4,
                         false, true, false, cute::SM80_16x8x4_F32TF32TF32F32_TN)
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8,
                         false, true, false, cute::SM80_16x8x8_F32TF32TF32F32_TN)
// FP64 inputs (DMMA: m8n8k4, TN layout)
TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true,
                         false, cute::SM80_8x8x4_F64F64F64F64_TN)

#undef TL_DEFINE_MMA_DISPATCHER

} // namespace detail

template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
          bool TransA, bool TransB, bool Saturate = false>
TL_DEVICE void mma_sync(
    typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
                                   TransB, Saturate>::CRegType *c,
    const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
                                         TransB, Saturate>::ARegType *a,
    const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
                                         TransB, Saturate>::BRegType *b) {
  using Dispatcher = detail::MmaDispatcher<AType, BType, CType, M, N, K,
                                           TransA, TransB, Saturate>;
  static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
                "tl::mma_sync: unsupported configuration");
  Dispatcher::exec(c, a, b, c);
}

} // namespace tl
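The dispatcher is meant to be driven through tl::mma_sync with the per-thread register fragments already laid out the way the underlying SM80/SM89 instruction expects. A hypothetical call site for one m16n8k16 FP16 -> FP32 tile (fragment setup omitted, and DataType is assumed to be the tl:: enum used throughout these headers; not part of the commit):

    __device__ void mma_tile_f16f32(float *acc, const unsigned *a_frag,
                                    const unsigned *b_frag) {
      // Resolves to cute::SM80_16x8x16_F32F16F16F32_TN via MmaDispatcher.
      tl::mma_sync<tl::DataType::kFloat16, tl::DataType::kFloat16,
                   tl::DataType::kFloat32, 16, 8, 16,
                   /*TransA=*/false, /*TransB=*/true>(acc, a_frag, b_frag);
    }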
src/tl_templates/cuda/instruction/mma_sm70.h (new file, 0 → 100644)

#pragma once

#include "../common.h"
#include <type_traits>
#include <utility>

namespace tl {

#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif

namespace detail {

// SM70 MMA Instruction Traits and Implementations
// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32
// accumulation

// Base template for SM70 MMA implementation
template <DataType AType, DataType BType, DataType CType, bool TransA, bool TransB>
struct MmaSm70Impl {
  // Default: unsupported configuration
  static constexpr bool kSupported = false;
  static TL_DEVICE void exec(void *, const void *, const void *, const void *) {
    static_assert(always_false_v<std::integral_constant<bool, TransA>>,
                  "tl::mma_sync_sm70: unsupported configuration");
  }
};

// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
                   true, true> {
  using DRegisters = unsigned[4];
  using ARegisters = unsigned[2];
  using BRegisters = unsigned[2];
  using CRegisters = unsigned[4];
  static constexpr bool kSupported = true;
  static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2, unsigned &d3,
                            unsigned a0, unsigned a1, unsigned b0, unsigned b1,
                            unsigned c0, unsigned c1, unsigned c2, unsigned c3) {
    asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 "
                 "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 : "r"(a0), "r"(a1), "r"(b0), "r"(b1),
                   "r"(c0), "r"(c1), "r"(c2), "r"(c3));
  }
};

// The col.row (true,false), row.col (false,true) and row.row (false,false)
// FP16-accumulation specializations are identical apart from the layout
// suffix in the PTX mnemonic
// (mma.sync.aligned.m8n8k4.{col.row|row.col|row.row}.f16.f16.f16.f16).

// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
                   true, true> {
  using DRegisters = float[8];
  using ARegisters = unsigned[2];
  using BRegisters = unsigned[2];
  using CRegisters = float[8];
  static constexpr bool kSupported = true;
  static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
                            float &d4, float &d5, float &d6, float &d7,
                            unsigned a0, unsigned a1, unsigned b0, unsigned b1,
                            float c0, float c1, float c2, float c3,
                            float c4, float c5, float c6, float c7) {
    asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
                 "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
                   "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
                 : "r"(a0), "r"(a1), "r"(b0), "r"(b1),
                   "f"(c0), "f"(c1), "f"(c2), "f"(c3),
                   "f"(c4), "f"(c5), "f"(c6), "f"(c7));
  }
};

// The col.row, row.col and row.row FP32-accumulation specializations follow
// the same pattern with the corresponding layout suffix
// (mma.sync.aligned.m8n8k4.{col.row|row.col|row.row}.f32.f16.f16.f32).

// Helper to extract register types
template <class Impl> struct MmaSm70ImplTraits {
  using DReg = std::remove_extent_t<typename Impl::DRegisters>;
  using AReg = std::remove_extent_t<typename Impl::ARegisters>;
  using BReg = std::remove_extent_t<typename Impl::BRegisters>;
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
  static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
  static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};

// Dispatcher for SM70 MMA operations
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
          bool TransA, bool TransB>
struct MmaSm70Dispatcher {
  using CRegType = void;
  using ARegType = void;
  using BRegType = void;
  static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
                             const CRegType *) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::mma_sync_sm70: unsupported configuration. "
                  "SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 "
                  "accumulation.");
  }
};

// Helper to call fma with unpacked register arrays
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx, size_t... CIdx>
TL_DEVICE void call_fma_impl_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
                                  const typename MmaSm70ImplTraits<Impl>::AReg *a,
                                  const typename MmaSm70ImplTraits<Impl>::BReg *b,
                                  const typename MmaSm70ImplTraits<Impl>::CReg *c,
                                  std::index_sequence<DIdx...>,
                                  std::index_sequence<AIdx...>,
                                  std::index_sequence<BIdx...>,
                                  std::index_sequence<CIdx...>) {
  Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}

template <class Impl>
TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
                             const typename MmaSm70ImplTraits<Impl>::AReg *a,
                             const typename MmaSm70ImplTraits<Impl>::BReg *b,
                             const typename MmaSm70ImplTraits<Impl>::CReg *c) {
  call_fma_impl_sm70<Impl>(
      d, a, b, c,
      std::make_index_sequence<MmaSm70ImplTraits<Impl>::kDRegs>{},
      std::make_index_sequence<MmaSm70ImplTraits<Impl>::kARegs>{},
      std::make_index_sequence<MmaSm70ImplTraits<Impl>::kBRegs>{},
      std::make_index_sequence<MmaSm70ImplTraits<Impl>::kCRegs>{});
}

// Define dispatchers for all supported SM70 configurations
// Note: m8n8k4 instruction computes m16n16k4 at warp level
#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum,          \
                                      TransAValue, TransBValue)                 \
  template <>                                                                   \
  struct MmaSm70Dispatcher<DataType::ATypeEnum, DataType::BTypeEnum,            \
                           DataType::CTypeEnum, 16, 16, 4, TransAValue,         \
                           TransBValue> {                                       \
    using Impl = MmaSm70Impl<DataType::ATypeEnum, DataType::BTypeEnum,          \
                             DataType::CTypeEnum, TransAValue, TransBValue>;    \
    using Traits = MmaSm70ImplTraits<Impl>;                                     \
    using CRegType = typename Traits::DReg;                                     \
    using ARegType = typename Traits::AReg;                                     \
    using BRegType = typename Traits::BReg;                                     \
    static_assert(                                                              \
        std::is_same_v<typename Traits::DReg, typename Traits::CReg>,           \
        "tl::mma_sync_sm70 requires matching accumulator/output regs");         \
    static TL_DEVICE void exec(CRegType *d, const ARegType *a,                  \
                               const BRegType *b, const CRegType *c) {          \
      call_fma_sm70<Impl>(d, a, b, c);                                          \
    }                                                                           \
  };

// FP16 inputs with FP16 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false)
// FP16 inputs with FP32 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false)

#undef TL_DEFINE_MMA_SM70_DISPATCHER

} // namespace detail

/// SM70 MMA synchronous instruction wrapper
/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs
/// and FP16/FP32 accumulation
///
/// @tparam AType Input A data type (kFloat16)
/// @tparam BType Input B data type (kFloat16)
/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32)
/// @tparam M Matrix M dimension (16)
/// @tparam N Matrix N dimension (16)
/// @tparam K Matrix K dimension (4)
/// @tparam TransA Whether A is transposed (false=row-major, true=col-major)
/// @tparam TransB Whether B is transposed (false=row-major, true=col-major)
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
          bool TransA, bool TransB>
TL_DEVICE void mma_sync_sm70(
    typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA,
                                       TransB>::CRegType *c,
    const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
                                             TransA, TransB>::ARegType *a,
    const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
                                             TransA, TransB>::BRegType *b) {
  using Dispatcher =
      detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA, TransB>;
  static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
                "tl::mma_sync_sm70: unsupported configuration. "
                "SM70 only supports m16n16k4 with FP16 inputs.");
  Dispatcher::exec(c, a, b, c);
}

} // namespace tl
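Per the register typedefs above, the FP32-accumulating m16n16k4 path carries eight float accumulators and two packed half-pair registers for each of A and B per thread. A hypothetical call site (not from the commit, fragment packing assumed done elsewhere) would look like:

    __device__ void mma_sm70_tile(float (&acc)[8], const unsigned (&a)[2],
                                  const unsigned (&b)[2]) {
      tl::mma_sync_sm70<tl::DataType::kFloat16, tl::DataType::kFloat16,
                        tl::DataType::kFloat32, 16, 16, 4,
                        /*TransA=*/false, /*TransB=*/true>(acc, a, b);
    }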
src/tl_templates/cuda/instruction/tcgen05mma.h
0 → 100644
View file @
bbbf4207
#pragma once
#include "../common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace
tl
{
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template
<
class
>
inline
constexpr
bool
always_false_v
=
false
;
#endif
// Generic declaration: unsupported by default
template
<
DataType
C_type
>
TL_DEVICE
void
tcgen05mma_ss
(
uint64_t
const
&
/*desc_a*/
,
uint64_t
const
&
/*desc_b*/
,
uint32_t
const
&
/*tmem_c*/
,
uint32_t
const
&
/*scalec*/
,
uint32_t
const
&
/*desc_val*/
,
int
const
&
/*mask0*/
,
int
const
&
/*mask1*/
,
int
const
&
/*mask2*/
,
int
const
&
/*mask3*/
)
{
static_assert
(
always_false_v
<
std
::
integral_constant
<
int
,
static_cast
<
int
>
(
C_type
)
>>
,
"tl::tcgen05mma_ss: unsupported accumulator type"
);
}
// TS variants: A from TMEM, B from SMEM (desc)
// Generic declaration: unsupported by default
template
<
DataType
C_type
>
TL_DEVICE
void
tcgen05mma_ts
(
uint32_t
const
&
/*tmem_a*/
,
uint64_t
const
&
/*desc_b*/
,
uint32_t
const
&
/*tmem_c*/
,
uint32_t
const
&
/*scalec*/
,
uint32_t
const
&
/*desc_val*/
,
int
const
&
/*mask0*/
,
int
const
&
/*mask1*/
,
int
const
&
/*mask2*/
,
int
const
&
/*mask3*/
)
{
static_assert
(
always_false_v
<
std
::
integral_constant
<
int
,
static_cast
<
int
>
(
C_type
)
>>
,
"tl::tcgen05mma_ts: unsupported accumulator type"
);
}
// F16/BF16 instruction kind (maps to kind::f16)
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kFloat16
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"r"
(
tmem_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// BF16 maps to the same f16-kind instruction
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kBFloat16
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
tcgen05mma_ts
<
DataType
::
kFloat16
>
(
tmem_a
,
desc_b
,
tmem_c
,
scalec
,
desc_val
,
mask0
,
mask1
,
mask2
,
mask3
);
}
// TF32 instruction kind
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kTensorFloat32
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"r"
(
tmem_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// INT8 instruction kind
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kInt8
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"r"
(
tmem_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kFloat8_e4m3
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, "
"{%5, %6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"r"
(
tmem_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
template
<
>
TL_DEVICE
void
tcgen05mma_ts
<
DataType
::
kFloat8_e5m2
>
(
uint32_t
const
&
tmem_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
tcgen05mma_ts
<
DataType
::
kFloat8_e4m3
>
(
tmem_a
,
desc_b
,
tmem_c
,
scalec
,
desc_val
,
mask0
,
mask1
,
mask2
,
mask3
);
}
// F16/BF16 instruction kind (maps to kind::f16)
template
<
>
TL_DEVICE
void
tcgen05mma_ss
<
DataType
::
kFloat16
>
(
uint64_t
const
&
desc_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
// idescE upper 32 bits carry the instruction descriptor; lower 32 ignored for
// SS Load TMEM base from shared memory slot handled by caller
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"l"
(
desc_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// BF16 maps to the same f16-kind instruction
template
<
>
TL_DEVICE
void
tcgen05mma_ss
<
DataType
::
kBFloat16
>
(
uint64_t
const
&
desc_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
tcgen05mma_ss
<
DataType
::
kFloat16
>
(
desc_a
,
desc_b
,
tmem_c
,
scalec
,
desc_val
,
mask0
,
mask1
,
mask2
,
mask3
);
}
// TF32 instruction kind
template
<
>
TL_DEVICE
void
tcgen05mma_ss
<
DataType
::
kTensorFloat32
>
(
uint64_t
const
&
desc_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"l"
(
desc_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// INT8 instruction kind
template
<
>
TL_DEVICE
void
tcgen05mma_ss
<
DataType
::
kInt8
>
(
uint64_t
const
&
desc_a
,
uint64_t
const
&
desc_b
,
uint32_t
const
&
tmem_c
,
uint32_t
const
&
scalec
,
uint32_t
const
&
desc_val
,
int
const
&
mask0
,
int
const
&
mask1
,
int
const
&
mask2
,
int
const
&
mask3
)
{
if
(
cute
::
elect_one_sync
())
{
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.ne.b32 p, %4, 0;
\n\t
"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p;
\n\t
"
"}
\n
"
:
:
"r"
(
tmem_c
),
"l"
(
desc_a
),
"l"
(
desc_b
),
"r"
(
desc_val
),
"r"
(
scalec
),
"r"
(
mask0
),
"r"
(
mask1
),
"r"
(
mask2
),
"r"
(
mask3
));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e4m3>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  if (cute::elect_one_sync()) {
    asm volatile("{\n\t"
                 ".reg .pred p;\n\t"
                 "setp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, "
                 "%6, %7, %8}, p;\n\t"
                 "}\n"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
                   "r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2),
                   "r"(mask3));
  }
}

template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e5m2>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  tcgen05mma_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
                                        desc_val, mask0, mask1, mask2, mask3);
}

// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx
// Generic declaration falls back to a static assert.
template <DataType C_type>
TL_DEVICE void tcgen05mma_ws_ss(uint64_t const & /*desc_a*/,
                                uint64_t const & /*desc_b*/,
                                uint32_t const & /*tmem_c*/,
                                uint32_t const & /*scalec*/,
                                uint32_t const & /*desc_val*/,
                                int const & /*mask0*/, int const & /*mask1*/,
                                int const & /*mask2*/, int const & /*mask3*/) {
  static_assert(
      always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
      "tl::tcgen05mma_ws_ss: unsupported accumulator type");
}

// F16/BF16 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat16>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  if (cute::elect_one_sync()) {
    asm volatile("{\n\t"
                 ".reg .pred p;\n\t"
                 "setp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0;\n\t"
                 "}\n"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
                   "r"(scalec));
  }
}

template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kBFloat16>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  tcgen05mma_ws_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec,
                                       desc_val, mask0, mask1, mask2, mask3);
}

// TF32 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kTensorFloat32>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  if (cute::elect_one_sync()) {
    asm volatile("{\n\t"
                 ".reg .pred p;\n\t"
                 "setp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0;\n\t"
                 "}\n"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
                   "r"(scalec));
  }
}

// INT8 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kInt8>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  if (cute::elect_one_sync()) {
    asm volatile("{\n\t"
                 ".reg .pred p;\n\t"
                 "setp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0;\n\t"
                 "}\n"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
                   "r"(scalec));
  }
}

// FP8 ws (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  if (cute::elect_one_sync()) {
    asm volatile("{\n\t"
                 ".reg .pred p;\n\t"
                 "setp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0;\n\t"
                 "}\n"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
                   "r"(scalec));
  }
}

template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e5m2>(
    uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
    uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
    int const &mask1, int const &mask2, int const &mask3) {
  tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
                                           desc_val, mask0, mask1, mask2,
                                           mask3);
}

} // namespace tl
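A minimal usage sketch (not part of the diff): how a Blackwell-generation kernel might drive the tcgen05 wrappers above, assuming DataType is the tl:: enum used throughout these headers. The shared-memory descriptors, tensor-memory address, and instruction descriptor are placeholders that TileLang's codegen would normally compute.

__device__ void example_issue_fp8_mma(uint64_t desc_a, uint64_t desc_b,
                                      uint32_t tmem_c, uint32_t idesc,
                                      bool accumulate) {
  // scale_c feeds the predicate: p = (scale_c != 0) selects accumulate vs
  // overwrite of the tensor-memory accumulator.
  const uint32_t scale_c = accumulate ? 1u : 0u;
  const int mask = 0; // no disable masks in this sketch
  tl::tcgen05mma_ss<tl::DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c,
                                                scale_c, idesc, mask, mask,
                                                mask, mask);
}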
src/tl_templates/cuda/instruction/wgmma.h
View file @ bbbf4207
#pragma once

#include "../common.h"
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <type_traits>
#include <utility>

namespace tl {

#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Primary class template. The default template arguments were removed because
// explicit specializations cannot declare default arguments.
//
// Removed in this merge: the previous hand-written WGMMA path. Its primary
// WgmmaSSImpl template only printed a "DEBUG: WgmmaSSImpl fallback ..."
// message (the intended static_assert was left commented out while debugging),
// and every tile shape carried its own inline-PTX specialization of the form
// "wgmma.mma_async.sync.aligned.m64nNkK.<ctype>.<atype>.<btype>" for
// F16/BF16/TF32/S8/U8/E4M3/E5M2 operands with F16/F32/S32 accumulators at
// N = 8 through 256, in both shared-shared (SS) and register-shared (RS)
// forms. The replacement below derives the same specializations from the CuTe
// SM90 GMMA atoms instead.
namespace detail {

template <bool IsMnMajor> struct MajorValue {
  static constexpr auto value = IsMnMajor ? cute::SM90::GMMA::Major::MN
                                          : cute::SM90::GMMA::Major::K;
};

template <int Scale> struct ScaleInValue {
  static_assert(Scale == 1 || Scale == -1,
                "tl::wgmma requires scale factors of +1 or -1.");
  static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One
                                           : cute::SM90::GMMA::ScaleIn::Neg;
};

template <int Scale>
inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1);

template <class Impl> struct CallWgmmaSS {
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
  static_assert(sizeof(CReg) == sizeof(uint32_t),
                "tl::wgmma_ss expects 32-bit accumulator registers.");

  template <size_t... Idx>
  TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c,
                            cute::SM90::GMMA::ScaleOut scale,
                            std::index_sequence<Idx...>) {
    Impl::fma(desc_a, desc_b, c[Idx]..., scale);
  }

  TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw,
                             bool scale_out) {
    auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
                           : cute::SM90::GMMA::ScaleOut::Zero;
    auto c = reinterpret_cast<CReg *>(c_raw);
    Run(desc_a, desc_b, c, scale, std::make_index_sequence<kCRegs>{});
  }
};

template <class Impl> struct CallWgmmaRS {
  using AReg = std::remove_extent_t<typename Impl::ARegisters>;
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
  static_assert(sizeof(AReg) == sizeof(uint32_t),
                "tl::wgmma_rs expects 32-bit register operands for A.");
  static_assert(sizeof(CReg) == sizeof(uint32_t) ||
                    sizeof(CReg) == sizeof(float),
                "tl::wgmma_rs expects 32-bit accumulator registers.");

  template <size_t... AIdx, size_t... CIdx>
  TL_DEVICE static void Run(const AReg *a, uint64_t desc_b, CReg *c,
                            cute::SM90::GMMA::ScaleOut scale,
                            std::index_sequence<AIdx...>,
                            std::index_sequence<CIdx...>) {
    Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
  }

  TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b,
                             uint32_t *c_raw, bool scale_out) {
    auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
                           : cute::SM90::GMMA::ScaleOut::Zero;
    auto a = reinterpret_cast<const AReg *>(a_raw);
    auto c = reinterpret_cast<CReg *>(c_raw);
    Run(a, desc_b, c, scale, std::make_index_sequence<kARegs>{},
        std::make_index_sequence<kCRegs>{});
  }
};

} // namespace detail
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
  static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
  static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
  TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::wgmma_ss: unsupported configuration");
  }
};

template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaRSImpl {
  static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_rs: invalid scaleA");
  static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_rs: invalid scaleB");
  TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::wgmma_rs: unsupported configuration");
  }
};
#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName)    \
  template <bool tnspA, bool tnspB, int scaleA, int scaleB>                   \
  struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, tnspA, tnspB, scaleA, scaleB> {                       \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_ss: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_ss: invalid scaleB");                            \
    using Impl =                                                              \
        cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value,          \
                                   detail::MajorValue<tnspB>::value,          \
                                   detail::ScaleInValue<scaleA>::value,       \
                                   detail::ScaleInValue<scaleB>::value>;      \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b,           \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out);          \
    }                                                                         \
  };

#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName)         \
  template <int scaleA, int scaleB>                                           \
  struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, false, false, scaleA, scaleB> {                       \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_ss: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_ss: invalid scaleB");                            \
    using Impl =                                                              \
        cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value,       \
                                   detail::ScaleInValue<scaleB>::value>;      \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b,           \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out);          \
    }                                                                         \
  };

#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K,       \
                                          ImplName)                           \
  template <int scaleA, int scaleB>                                           \
  struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, false, false, scaleA, scaleB> {                       \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_ss: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_ss: invalid scaleB");                            \
    static_assert(scaleA == 1 && scaleB == 1,                                 \
                  "tl::wgmma_ss: only +1 scaling supported for this WGMMA");  \
    using Impl = cute::SM90::GMMA::ImplName;                                  \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b,           \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out);          \
    }                                                                         \
  };

#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName)    \
  template <bool tnspA, bool tnspB, int scaleA, int scaleB>                   \
  struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, tnspA, tnspB, scaleA, scaleB> {                       \
    static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major");         \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_rs: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_rs: invalid scaleB");                            \
    using Impl =                                                              \
        cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value,          \
                                   detail::MajorValue<tnspB>::value,          \
                                   detail::ScaleInValue<scaleA>::value,       \
                                   detail::ScaleInValue<scaleB>::value>;      \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b,         \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out);               \
    }                                                                         \
  };

#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName)         \
  template <int scaleA, int scaleB>                                           \
  struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, false, false, scaleA, scaleB> {                       \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_rs: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_rs: invalid scaleB");                            \
    using Impl =                                                              \
        cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value,       \
                                   detail::ScaleInValue<scaleB>::value>;      \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b,         \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out);               \
    }                                                                         \
  };

#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K,       \
                                          ImplName)                           \
  template <int scaleA, int scaleB>                                           \
  struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
                     K, false, false, scaleA, scaleB> {                       \
    static_assert(detail::IsValidScale<scaleA>,                               \
                  "tl::wgmma_rs: invalid scaleA");                            \
    static_assert(detail::IsValidScale<scaleB>,                               \
                  "tl::wgmma_rs: invalid scaleB");                            \
    static_assert(scaleA == 1 && scaleB == 1,                                 \
                  "tl::wgmma_rs: only +1 scaling supported for this WGMMA");  \
    using Impl = cute::SM90::GMMA::ImplName;                                  \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b,         \
                                  uint32_t *c, bool scale_out) {              \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out);               \
    }                                                                         \
  };

#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP)                                     \
  OP(8) OP(16) OP(24) OP(32) OP(40) OP(48) OP(56) OP(64)                      \
  OP(72) OP(80) OP(88) OP(96) OP(104) OP(112) OP(120) OP(128)                 \
  OP(136) OP(144) OP(152) OP(160) OP(168) OP(176) OP(184) OP(192)             \
  OP(200) OP(208) OP(216) OP(224) OP(232) OP(240) OP(248) OP(256)
#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(48) \
OP(64) \
OP(80) \
OP(96) \
OP(112) \
OP(128) \
OP(144) \
OP(160) \
OP(176) \
OP(192) \
OP(208) \
OP(224) \
OP(240) \
OP(256)
#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_SS)
#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_SS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_SS)
#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_SS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN);
#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_RS)
#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_RS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_RS)
#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_RS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN);
#undef TL_WGMMA_DEFINE_F16_F16_F16_SS
#undef TL_WGMMA_DEFINE_F16_F16_F32_SS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS
#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_F16_F16_RS
#undef TL_WGMMA_DEFINE_F16_F16_F32_RS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS
#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN
#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8
#undef TL_WGMMA_FOREACH_N_INT32_MUL8
#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_SS_GENERAL
#undef TL_WGMMA_DEFINE_SS_TN
#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_RS_GENERAL
#undef TL_WGMMA_DEFINE_RS_TN
// The function templates delegate to the class templates.
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
...
@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
              scaleB>::execute(desc_a, desc_b, c, scale_out);
}
// ================================= Mixed Precision Support
// =================================
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c,
                        bool scale_out) {
  WgmmaRSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
              scaleB>::execute(a, desc_b, c, scale_out);
}

// Mixed precision: S8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
                   32, tnspA, tnspB, scaleA, scaleB> {
  TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
                                bool scale_out) {
    asm volatile("{\n"
                 ".reg .pred p;\n"
                 "setp.ne.b32 p, %4, 0;\n"
                 "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
                 "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
                 "}\n"
                 : "+r"(c[0]), "+r"(c[1])
                 : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
                   "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
                   "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
  }
};

// The remaining mixed-precision M64N8K32 specializations (U8 x S8 -> S32,
// U8 x U8 -> S32, E4M3 x E5M2 -> F16, E5M2 x E4M3 -> F16) follow the identical
// pattern above, differing only in the PTX opcode suffix
// (s32.u8.s8, s32.u8.u8, f16.e4m3.e5m2, f16.e5m2.e4m3).
// ================================= Convenience Templates
// =================================
// Type trait to determine the number of output registers needed
template <DataType C_type, int M, int N> struct WgmmaOutputRegs {
  static constexpr int value =
      (M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8);
};

// Type trait to get element size in bits
template <DataType dtype> struct ElementBits {
  static constexpr int value =
      (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
       dtype == DataType::kInt32)
          ? 32
      : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
         dtype == DataType::kInt16 || dtype == DataType::kUInt16)
          ? 16
      : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
         dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
          ? 8
      : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
                                                                : 8;
};

} // namespace tl
\ No newline at end of file
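A minimal usage sketch (not part of the diff), assuming the wgmma_ss entry point shown above: issuing one M64N8K16 fp16 WGMMA from a Hopper kernel. The shared-memory descriptors and the two-register accumulator fragment are placeholders produced by the real GEMM lowering.

__device__ void example_wgmma_tile(uint64_t desc_a, uint64_t desc_b,
                                   uint32_t (&acc)[2], bool first_k_step) {
  // fp16 x fp16 -> fp16, K-major A and B, default +1 input scaling;
  // scale_out=false on the first K step zero-initializes the accumulator.
  tl::wgmma_ss<tl::DataType::kFloat16, tl::DataType::kFloat16,
               tl::DataType::kFloat16, 64, 8, 16,
               /*tnspA=*/false, /*tnspB=*/false>(desc_a, desc_b, acc,
                                                 /*scale_out=*/!first_k_step);
}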
src/tl_templates/cuda/intrin.h
View file @ bbbf4207
...
@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() {
  cute::warpgroup_wait<NumMma>();
}

TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) {
#pragma unroll
  for (int i = 0; i < count; ++i) {
    cute::warpgroup_fence_operand(regs[i]);
  }
}

TL_DEVICE void warpgroup_fence_operand(float *regs, int count) {
#pragma unroll
  for (int i = 0; i < count; ++i) {
    cute::warpgroup_fence_operand(regs[i]);
  }
}

// Template parameter:
//   thread_extent: the logical size (in number of threads) of each "group"
//   within which we want to elect exactly ONE representative
...
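A minimal sketch (not part of the diff) of how the new fence helpers are typically used, assuming the cute::warpgroup_arrive/commit_batch/wait primitives already pulled in by these headers: the fences bracket the asynchronous MMA so the compiler does not reorder accumulator accesses around it.

__device__ void example_fenced_mma(uint64_t desc_a, uint64_t desc_b,
                                   uint32_t (&acc)[2]) {
  tl::warpgroup_fence_operand(acc, 2); // fence accumulator before async MMA
  cute::warpgroup_arrive();
  tl::wgmma_ss<tl::DataType::kFloat16, tl::DataType::kFloat16,
               tl::DataType::kFloat16, 64, 8, 16, false, false>(
      desc_a, desc_b, acc, /*scale_out=*/true);
  cute::warpgroup_commit_batch();
  cute::warpgroup_wait<0>();           // wait for all in-flight WGMMAs
  tl::warpgroup_fence_operand(acc, 2); // fence again before reading results
}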
src/tl_templates/cuda/reduce.h
View file @ bbbf4207
#pragma once

#include "common.h"
#include <cstdint>
#include <type_traits>

namespace tl {

// Select a wider accumulator type for improved numerical accuracy.
// Default: accumulate in the same type. Specialize FP16/BF16 to float.
template <typename T> struct AccType {
  using type = T;
};
template <> struct AccType<half_t> {
  using type = float;
};
template <> struct AccType<bfloat16_t> {
  using type = float;
};

struct SumOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x + y;
...
@@ -40,53 +54,6 @@ struct BitXorOp {
  }
};

// Removed by this change: the shared-memory warp reducer that predated the
// tl:: shuffle helpers.
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
  template <typename T>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int total_dest, int reduce_extent, int tail,
                            T init_value) {
    if (total_dest <= 0 || reduce_extent <= 0)
      return;
    constexpr int kWarpSize = 32;
    static_assert(Threads % kWarpSize == 0,
                  "SharedReduceWarp expects blockDim.x to be a multiple of "
                  "warp size on CUDA.");
    const int tid = threadIdx.x;
    const int warp_id = tid / kWarpSize;
    const int lane = tid % kWarpSize;
    const int num_warps = Threads / kWarpSize;
    for (int dest_idx = warp_id; dest_idx < total_dest;
         dest_idx += num_warps) {
      const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
      const int suffix = tail == 1 ? 0 : dest_idx % tail;
      const int src_base = (prefix * reduce_extent) * tail + suffix;
      const int dst_index = prefix * tail + suffix;
      T partial = init_value;
      for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
        T val = src[src_base + rv * tail];
        if constexpr (UseAbs) {
          val = val < T(0) ? -val : val;
        }
        partial = Reducer()(partial, val);
      }
      unsigned mask = __activemask();
      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
        T other = __shfl_down_sync(mask, partial, offset);
        partial = Reducer()(partial, other);
      }
      if (lane == 0) {
        if constexpr (NeedAccumulate) {
          partial = Reducer()(dst[dst_index], partial);
        }
        dst[dst_index] = partial;
      }
    }
  }
};

template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
...
@@ -102,7 +69,7 @@ struct AllReduce {
      __syncthreads();
      x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
    } else {
-      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
    }
    if constexpr (offset == scale) {
      return x;
...
@@ -122,7 +89,7 @@ struct AllReduce {
      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
      x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
    } else {
-      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+      x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
    }
    if constexpr (offset == scale) {
      return x;
...
@@ -159,7 +126,7 @@ template <int threads, bool reverse = false> struct CumSum1D {
#pragma unroll
    for (int off = 1; off < SEG; off <<= 1) {
-      T n = (T)__shfl_down_sync(MASK, val, off);
+      T n = (T)tl::shfl_down_sync(MASK, val, off);
      if (lane < SEG - off)
        val += n;
    }
...
@@ -234,7 +201,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
    for (int off = 1; off < SEG; off <<= 1) {
-      T n = (T)__shfl_down_sync(MASK, val, off);
+      T n = tl::shfl_down_sync(MASK, val, off);
      if (lane < SEG - off)
        val += n;
    }
...
@@ -244,10 +211,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
      if (real_col < W)
        dst[real_row * W + real_col] = val;
-      T segSum = (T)__shfl_sync(MASK, val, (T)0);
+      T segSum = tl::shfl_sync(MASK, val, 0);
      if (lane == 0)
        carry = segSum;
-      carry = (T)__shfl_sync(MASK, carry, (T)0);
+      carry = tl::shfl_sync(MASK, carry, 0);
    }
  } else {
    for (int seg = 0; seg * SEG < W; ++seg) {
...
@@ -260,7 +227,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
    for (int off = 1; off < SEG; off <<= 1) {
-      T n = (T)__shfl_up_sync(MASK, val, off);
+      T n = tl::shfl_up_sync(MASK, val, off);
      if (lane >= off)
        val += n;
    }
...
@@ -270,10 +237,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
      if (real_col < W)
        dst[real_row * W + real_col] = val;
-      T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
+      T segSum = tl::shfl_sync(MASK, val, SEG - 1);
      if (lane == SEG - 1)
        carry = segSum;
-      carry = (T)__shfl_sync(MASK, carry, SEG - 1);
+      carry = tl::shfl_sync(MASK, carry, SEG - 1);
    }
  }
}
...
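A minimal sketch (not part of the diff) of the block-wide reduction these templates implement; the run(x, red_buf) entry point and the 128-thread configuration are assumptions made here for illustration.

__device__ float example_block_sum(float x) {
  __shared__ float red_buf[128];
  // Reducer = SumOp, threads = 128, scale = 1: every thread returns the
  // block-wide total; red_buf is only touched once the exchange distance
  // exceeds a single warp.
  return tl::AllReduce<tl::SumOp, 128, 1>::run(x, red_buf);
}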
src/tl_templates/cuda/tcgen_05.h
View file @ bbbf4207
...
@@ -6,6 +6,7 @@
#endif

#include "common.h"
+#include <cute/arch/cluster_sm90.hpp>

namespace tl {
...
@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
               "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
}

-inline __device__ void amma_commit(uint64_t const *smem_ptr) {
-  uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
-  asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
-               "cluster.b64 [%0];"
-               :
-               : "r"(bar_intptr));
-}
+// Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier
+TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) {
+  uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
+  if (cute::elect_one_sync()) {
+    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
+                 "cluster.b64 [%0];"
+                 :
+                 : "r"(bar_intptr));
+  }
+}

} // namespace tl
\ No newline at end of file
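A minimal sketch (not part of the diff): the producer side signals MMA completion into a shared-memory mbarrier with the wrapper above; smem_bar is a placeholder for a barrier the real pipeline allocates and initializes.

__device__ void example_commit(uint64_t *smem_bar) {
  // One elected lane commits the outstanding tcgen05 MMAs to the mbarrier;
  // consumers wait on the same barrier before reading tensor memory.
  tl::tcgen05_mma_arrive(smem_bar);
}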
src/transform/align_dynamic_shared_memory_allocations.cc
View file @ bbbf4207
...
@@ -47,7 +47,7 @@ public:
...
@@ -47,7 +47,7 @@ public:
}
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Block
block
=
GetRef
<
Block
>
(
op
);
Block
block
=
tvm
::
ffi
::
GetRef
<
Block
>
(
op
);
Array
<
Buffer
>
alloc_buffers
=
op
->
alloc_buffers
;
Array
<
Buffer
>
alloc_buffers
=
op
->
alloc_buffers
;
alloc_buffers
.
MutateByApply
([
this
](
Buffer
buf
)
{
alloc_buffers
.
MutateByApply
([
this
](
Buffer
buf
)
{
auto
storage_scope
=
auto
storage_scope
=
...
@@ -58,7 +58,7 @@ public:
...
@@ -58,7 +58,7 @@ public:
buf
->
dtype
.
bytes
());
buf
->
dtype
.
bytes
());
if
(
!
new_shape
.
same_as
(
buf
->
shape
))
{
if
(
!
new_shape
.
same_as
(
buf
->
shape
))
{
ObjectPtr
<
BufferNode
>
new_buffer
=
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buf
.
get
()));
tvm
::
ffi
::
make_object
<
BufferNode
>
(
*
(
buf
.
get
()));
new_buffer
->
shape
=
std
::
move
(
new_shape
);
new_buffer
->
shape
=
std
::
move
(
new_shape
);
buffer_remap_
.
Set
(
buf
,
Buffer
(
new_buffer
));
buffer_remap_
.
Set
(
buf
,
Buffer
(
new_buffer
));
return
Buffer
(
new_buffer
);
return
Buffer
(
new_buffer
);
...
@@ -73,7 +73,7 @@ public:
...
@@ -73,7 +73,7 @@ public:
   }
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store_node = GetRef<BufferStore>(op);
+    auto store_node = tvm::ffi::GetRef<BufferStore>(op);
     Buffer buf = op->buffer;
     if (buffer_remap_.count(buf)) {
       buf = buffer_remap_[buf];
...
@@ -83,7 +83,7 @@ public:
   }
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load_node = GetRef<BufferLoad>(op);
+    auto load_node = tvm::ffi::GetRef<BufferLoad>(op);
     Buffer buf = op->buffer;
     if (buffer_remap_.count(buf)) {
       buf = buffer_remap_[buf];
...
@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations"
,
{});
"tl.AlignDynamicSharedMemoryAllocations"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.AlignDynamicSharedMemoryAllocations"
,
refl
::
GlobalDef
().
def
(
"tl.transform.AlignDynamicSharedMemoryAllocations"
,
AlignDynamicSharedMemoryAllocations
);
AlignDynamicSharedMemoryAllocations
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
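For orientation: this pass pads dynamic shared-memory buffer shapes so allocations respect the requested align_bytes. The exact shape rewrite lives outside the hunks shown here, so the snippet below is only a minimal sketch of the round-up arithmetic such padding relies on; the align_up helper and the fp16 example are illustrative assumptions, not code from the pass.

#include <cstdint>
#include <cstdio>

// Round `bytes` up to the next multiple of `align_bytes` (a power of two).
static std::uint64_t align_up(std::uint64_t bytes, std::uint64_t align_bytes) {
  return (bytes + align_bytes - 1) & ~(align_bytes - 1);
}

int main() {
  // A [128, 33] float16 buffer has 33 * 2 = 66 bytes per row.
  std::uint64_t row_bytes = 33 * 2;
  std::uint64_t padded = align_up(row_bytes, 16);  // -> 80 bytes (40 fp16 elements)
  std::printf("padded row: %llu bytes\n", static_cast<unsigned long long>(padded));
  return 0;
}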
src/transform/annotate_device_regions.cc
View file @
bbbf4207
...
@@ -46,13 +46,13 @@ public:
   Stmt VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == tvm::attr::kTarget) {
       // If a target attribute already exists, use it as-is.
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else if (op->attr_key == tir::attr::thread_extent ||
                op->attr_key == tir::attr::pipeline_exec_scope ||
                op->attr_key == tir::attr::device_scope) {
       // These attributes are only allowed in device-side code, so
       // they should be annotated with the function's default target.
-      Stmt body = GetRef<Stmt>(op);
+      Stmt body = tvm::ffi::GetRef<Stmt>(op);
       return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
     } else {
       // All other annotations are ignored
...
@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
   return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
                         AnnotateDeviceRegions);
-});
+}
 } // namespace tl
 } // namespace tvm
src/transform/annotate_warp_group_reg_alloc.cc
View file @
bbbf4207
...
@@ -124,7 +124,9 @@ private:
   }
   auto producer_body = if_then_else->then_case;
   Optional<Stmt> consumer_body = if_then_else->else_case;
-  ICHECK(consumer_body.defined()) << "Consumer body is undefined";
+  // In some degenerate warp-specialized patterns (e.g., producer-only),
+  // the consumer body may be absent. Handle gracefully by only annotating
+  // the producer side when consumer is missing.
   auto dec_reg = nreg_[0].as<IntImmNode>()->value;
   auto inc_reg = nreg_[1].as<IntImmNode>()->value;
...
@@ -150,15 +152,20 @@ private:
   producer_stmts.push_back(producer_body);
   auto new_producer_body = SeqStmt(producer_stmts);
-  Array<Stmt> consumer_stmts;
-  consumer_stmts.push_back(inc_reg_stmt);
-  consumer_stmts.push_back(consumer_body.value());
-  auto new_consumer_body = SeqStmt(consumer_stmts);
-  auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
-                                new_consumer_body);
+  Stmt new_if_stmt;
+  if (consumer_body.defined()) {
+    Array<Stmt> consumer_stmts;
+    consumer_stmts.push_back(inc_reg_stmt);
+    consumer_stmts.push_back(consumer_body.value());
+    auto new_consumer_body = SeqStmt(consumer_stmts);
+    new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
+                             new_consumer_body);
+  } else {
+    // No consumer branch; keep the if-then form.
+    new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
+  }
   auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
   return new_attr;
 } else {
   return StmtExprMutator::VisitStmt_(op);
...
@@ -181,11 +188,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
   return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
                         AnnotateWarpGroupRegAlloc);
-});
+}
 } // namespace tl
 } // namespace tvm
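This pass wraps the producer branch with a register-decrement statement and the consumer branch with a register-increment statement. As a rough picture of what the generated device code amounts to on Hopper, here is a minimal CUDA sketch that issues the setmaxnreg PTX instruction directly; the register counts (40/232), the sm_90 guard, and all names are illustrative assumptions, not values taken from the pass.

#include <cuda_runtime.h>

__device__ void producer_dealloc_regs() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("setmaxnreg.dec.sync.aligned.u32 40;");   // shrink this warp group's register budget
#endif
}

__device__ void consumer_alloc_regs() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("setmaxnreg.inc.sync.aligned.u32 232;");  // grow this warp group's register budget
#endif
}

__global__ void warp_specialized_kernel() {
  int warp_group = threadIdx.x / 128;  // 128 threads per warp group
  if (warp_group == 0) {
    producer_dealloc_regs();  // producer: hand registers back, then issue copies
  } else {
    consumer_alloc_regs();    // consumer: claim extra registers, then run MMAs
  }
}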
src/transform/arg_binder.cc
View file @
bbbf4207
...
@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
   Bind_(arg, value, arg_name, with_let);
 }
-void ArgBinder::BindArray(const Array<PrimExpr> &arg,
-                          const Array<PrimExpr> &value,
+void ArgBinder::BindArray(const ffi::Array<PrimExpr> &arg,
+                          const ffi::Array<PrimExpr> &value,
                           const std::string &arg_name) {
   ICHECK_EQ(arg.size(), value.size())
       << "Argument " << arg_name << " array size mismatch";
...
@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
   // Assert the buffer is compact
   DataType stype = buffer->DefaultIndexType();
   PrimExpr expect_stride = make_const(stype, 1);
-  Array<PrimExpr> conds;
+  ffi::Array<PrimExpr> conds;
   for (size_t i = buffer->shape.size(); i != 0; --i) {
     size_t k = i - 1;
     PrimExpr svalue =
...
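The compactness assert above builds the expected strides from the shape, starting at 1 for the innermost dimension and multiplying outward. A minimal, TVM-free C++ sketch of that check follows; is_compact and the sample shapes are illustrative only.

#include <cstdint>
#include <iostream>
#include <vector>

// True when `strides` are exactly the dense row-major strides implied by `shape`.
bool is_compact(const std::vector<int64_t> &shape,
                const std::vector<int64_t> &strides) {
  int64_t expect_stride = 1;
  for (size_t i = shape.size(); i != 0; --i) {
    size_t k = i - 1;
    if (strides[k] != expect_stride) return false;
    expect_stride *= shape[k];
  }
  return true;
}

int main() {
  std::cout << is_compact({2, 3, 4}, {12, 4, 1}) << "\n";  // 1: dense layout
  std::cout << is_compact({2, 3, 4}, {16, 4, 1}) << "\n";  // 0: padded rows
  return 0;
}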
src/transform/arg_binder.h
View file @
bbbf4207
...
@@ -82,7 +82,8 @@ public:
    * \param value The target expression value
    * \param arg_name argument name.
    */
-  void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value,
+  void BindArray(const ffi::Array<PrimExpr> &arg,
+                 const ffi::Array<PrimExpr> &value,
                  const std::string &arg_name);
   /*!
    * \brief Bind symbolic buffer to another symbolic buffer
...
@@ -149,7 +150,7 @@ public:
   */
   const std::vector<Stmt> &init_nest() const { return init_nest_; }
   /*! \return Handle data type of the data */
-  const Map<Var, PrimExpr> &def_handle_dtype() const {
+  const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
     return def_handle_dtype_;
   }
...
@@ -164,7 +165,7 @@ private:
   /*! \brief Initialize nest */
   std::vector<Stmt> init_nest_;
   /*! \brief handle data type in the defintiions */
-  Map<Var, PrimExpr> def_handle_dtype_;
+  ffi::Map<Var, PrimExpr> def_handle_dtype_;
   /*! \brief asserts generated */
   std::vector<Stmt> asserts_;
   /*! \brief internal analyzer. */
...