Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8f4628e0
Commit
8f4628e0
authored
Nov 11, 2025
by
qisan
Browse files
[Bugfix] Using a new data layout and the performance of NN gemm exceeds rocblas
parent
9a640856
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
124 additions
and
158 deletions
+124
-158
examples/gemm/example_gemm_intrinsics_dcu.py
examples/gemm/example_gemm_intrinsics_dcu.py
+4
-5
src/layout/gemm_layouts.cc
src/layout/gemm_layouts.cc
+5
-4
src/layout/layout.h
src/layout/layout.h
+2
-2
src/op/gemm.cc
src/op/gemm.cc
+5
-6
src/target/intrin_rule_hip.cc
src/target/intrin_rule_hip.cc
+0
-1
src/target/utils.cc
src/target/utils.cc
+0
-1
src/tl_templates/dcu_hip/common.h
src/tl_templates/dcu_hip/common.h
+18
-24
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+0
-1
src/tl_templates/dcu_hip/core.hpp
src/tl_templates/dcu_hip/core.hpp
+28
-57
src/tl_templates/dcu_hip/debug.h
src/tl_templates/dcu_hip/debug.h
+0
-1
src/tl_templates/dcu_hip/gemm.h
src/tl_templates/dcu_hip/gemm.h
+15
-16
src/tl_templates/dcu_hip/hip_fp8.h
src/tl_templates/dcu_hip/hip_fp8.h
+0
-1
src/tl_templates/dcu_hip/reduce.h
src/tl_templates/dcu_hip/reduce.h
+13
-11
src/tl_templates/dcu_hip/threadblock_swizzle.h
src/tl_templates/dcu_hip/threadblock_swizzle.h
+0
-1
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
+3
-2
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+31
-25
No files found.
examples/gemm/example_gemm_intrinsics_dcu.py
View file @
8f4628e0
...
@@ -10,6 +10,7 @@ from tilelang import disable_cache
...
@@ -10,6 +10,7 @@ from tilelang import disable_cache
disable_cache
()
disable_cache
()
def
make_swizzle_layout
(
shared_buf
):
def
make_swizzle_layout
(
shared_buf
):
dtype
=
shared_buf
.
dtype
dtype
=
shared_buf
.
dtype
shape
=
shared_buf
.
shape
shape
=
shared_buf
.
shape
...
@@ -186,5 +187,3 @@ def main():
...
@@ -186,5 +187,3 @@ def main():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
src/layout/gemm_layouts.cc
View file @
8f4628e0
...
@@ -169,7 +169,7 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
...
@@ -169,7 +169,7 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
auto
warp_layout
=
auto
warp_layout
=
base_layout
->
Repeat
({
warp_m
/
16
,
warp_n
/
16
},
false
,
false
);
base_layout
->
Repeat
({
warp_m
/
16
,
warp_n
/
16
},
false
,
false
);
auto
block_layout
=
auto
block_layout
=
warp_layout
->
Repeat
({
block_m
/
warp_m
,
block_n
/
warp_n
},
true
,
tru
e
);
warp_layout
->
Repeat
({
block_m
/
warp_m
,
block_n
/
warp_n
},
true
,
fals
e
);
return
block_layout
;
return
block_layout
;
}
}
...
@@ -747,7 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
...
@@ -747,7 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
if
(
!
k_inner
&&
element_size
==
8
)
// int8 KxN
if
(
!
k_inner
&&
element_size
==
8
)
// int8 KxN
return
makeGemmABLayoutPadded
(
mat_stride
,
mat_continuous
,
element_size
);
return
makeGemmABLayoutPadded
(
mat_stride
,
mat_continuous
,
element_size
);
else
if
(
mat_continuous
%
(
vector_size
*
8
)
==
0
)
else
if
(
mat_continuous
%
(
vector_size
*
8
)
==
0
)
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
// element_size);
return
makeFullBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
return
makeFullBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
else
if
(
mat_continuous
%
(
vector_size
*
4
)
==
0
)
else
if
(
mat_continuous
%
(
vector_size
*
4
)
==
0
)
return
makeHalfBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
return
makeHalfBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
...
...
src/layout/layout.h
View file @
8f4628e0
src/op/gemm.cc
View file @
8f4628e0
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
*/
*/
#include "gemm.h"
#include "gemm.h"
#include <fstream>
#include "builtin.h"
#include "builtin.h"
#include <fstream>
#include <tvm/tir/builtin.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/op_attr_types.h>
...
@@ -828,8 +828,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
...
@@ -828,8 +828,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
<<
"CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<<
"CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<<
C
.
scope
();
<<
C
.
scope
();
if
(
TargetIsDCU
(
T
.
target
))
if
(
TargetIsDCU
(
T
.
target
))
{
{
auto
fragment
=
auto
fragment
=
makeGemmFragmentCDCU
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
makeGemmFragmentCDCU
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
...
...
src/target/intrin_rule_hip.cc
View file @
8f4628e0
...
@@ -249,7 +249,6 @@ TVM_REGISTER_OP("tir.hip.__shfl")
...
@@ -249,7 +249,6 @@ TVM_REGISTER_OP("tir.hip.__shfl")
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TVM_REGISTER_OP
(
"tir.hip.__shfl_sync"
)
TVM_REGISTER_OP
(
"tir.hip.__shfl_sync"
)
.
set_num_inputs
(
4
)
.
set_num_inputs
(
4
)
.
add_argument
(
"mask"
,
"Expr"
,
"The thread mask."
)
.
add_argument
(
"mask"
,
"Expr"
,
"The thread mask."
)
...
...
src/target/utils.cc
View file @
8f4628e0
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
#include "utils.h"
#include "utils.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
...
...
src/tl_templates/dcu_hip/common.h
View file @
8f4628e0
#pragma once
#pragma once
#include "core.hpp"
#include "core.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
...
@@ -106,22 +105,19 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
...
@@ -106,22 +105,19 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return
(
v1
<<
16
)
|
v0
;
return
(
v1
<<
16
)
|
v0
;
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
struct
is_half_type
:
std
::
false_type
{};
template
<
>
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
template
<
>
struct
is_half_type
<
half_t
>
:
std
::
true_type
{};
struct
is_half_type
<
half_t
>
:
std
::
true_type
{};
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
void
AtomicAdd
(
T1
*
address
,
T2
val
)
{
TL_DEVICE
void
AtomicAdd
(
T1
*
address
,
T2
val
)
{
if
constexpr
(
is_half_v
<
T1
>
)
{
if
constexpr
(
is_half_v
<
T1
>
)
{
__half
*
addr
=
reinterpret_cast
<
__half
*>
(
address
);
__half
*
addr
=
reinterpret_cast
<
__half
*>
(
address
);
__half
hval
=
__float2half
(
static_cast
<
float
>
(
val
));
__half
hval
=
__float2half
(
static_cast
<
float
>
(
val
));
atomicAdd
(
addr
,
hval
);
atomicAdd
(
addr
,
hval
);
}
else
{
}
else
{
...
@@ -129,16 +125,14 @@ TL_DEVICE void AtomicAdd(T1* address, T2 val) {
...
@@ -129,16 +125,14 @@ TL_DEVICE void AtomicAdd(T1* address, T2 val) {
}
}
}
}
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
void
AtomicAdd
(
T1
&
ref
,
T2
val
)
{
TL_DEVICE
void
AtomicAdd
(
T1
&
ref
,
T2
val
)
{
AtomicAdd
(
&
ref
,
val
);
AtomicAdd
(
&
ref
,
val
);
}
}
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
T1
AtomicAddRet
(
T1
&
ref
,
T2
val
)
{
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
T1
AtomicAddRet
(
T1
&
ref
,
T2
val
)
{
return
atomicAdd
(
&
ref
,
static_cast
<
T1
>
(
val
));
return
atomicAdd
(
&
ref
,
static_cast
<
T1
>
(
val
));
}
}
template
<
typename
T
>
template
<
typename
T
>
TL_DEVICE
void
AtomicAddx4
(
T
*
ref
,
const
T
val
[
4
])
{
TL_DEVICE
void
AtomicAddx4
(
T
*
ref
,
const
T
val
[
4
])
{
atomicAdd
(
&
ref
[
0
],
val
[
0
]);
atomicAdd
(
&
ref
[
0
],
val
[
0
]);
atomicAdd
(
&
ref
[
1
],
val
[
1
]);
atomicAdd
(
&
ref
[
1
],
val
[
1
]);
atomicAdd
(
&
ref
[
2
],
val
[
2
]);
atomicAdd
(
&
ref
[
2
],
val
[
2
]);
...
...
src/tl_templates/dcu_hip/copy.h
View file @
8f4628e0
...
@@ -108,4 +108,3 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
...
@@ -108,4 +108,3 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
}
}
}
// namespace tl
}
// namespace tl
src/tl_templates/dcu_hip/core.hpp
View file @
8f4628e0
...
@@ -25,82 +25,53 @@
...
@@ -25,82 +25,53 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#endif
namespace
ck_tile
{
namespace
ck_tile
{
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
max
(
T
x
)
{
return
x
;
}
CK_TILE_HOST_DEVICE
constexpr
T
max
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
max
(
T
x
,
T
y
)
{
CK_TILE_HOST
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
return
x
>
y
?
x
:
y
;
}
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
{
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
return
x
>
y
?
x
:
y
;
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
float
max
(
float
x
,
float
y
)
{
CK_TILE_DEVICE
float
max
(
float
x
,
float
y
)
{
return
__builtin_fmaxf
(
x
,
y
);
// can resultin v_max3_f32
return
__builtin_fmaxf
(
x
,
y
);
// can resultin v_max3_f32
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
double
max
(
double
x
,
double
y
)
{
CK_TILE_DEVICE
double
max
(
double
x
,
double
y
)
{
return
__builtin_fmax
(
x
,
y
);
// maybe still v_max3_f32
return
__builtin_fmax
(
x
,
y
);
// maybe still v_max3_f32
}
}
template
<
typename
X
,
typename
...
Ys
>
template
<
typename
X
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
CK_TILE_HOST_DEVICE
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
{
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
max
(
x
,
max
(
ys
...));
return
max
(
x
,
max
(
ys
...));
}
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
min
(
T
x
)
{
return
x
;
}
CK_TILE_HOST_DEVICE
constexpr
T
min
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
min
(
T
x
,
T
y
)
{
CK_TILE_HOST
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
return
x
<
y
?
x
:
y
;
}
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
min
(
T
x
,
T
y
)
{
CK_TILE_DEVICE
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
return
x
<
y
?
x
:
y
;
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
float
min
(
float
x
,
float
y
)
{
CK_TILE_DEVICE
float
min
(
float
x
,
float
y
)
{
return
__builtin_fminf
(
x
,
y
);
return
__builtin_fminf
(
x
,
y
);
}
}
template
<
>
template
<
>
CK_TILE_DEVICE
double
min
(
double
x
,
double
y
)
{
CK_TILE_DEVICE
double
min
(
double
x
,
double
y
)
{
return
__builtin_fmin
(
x
,
y
);
return
__builtin_fmin
(
x
,
y
);
}
}
template
<
typename
X
,
typename
...
Ys
>
template
<
typename
X
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
CK_TILE_HOST_DEVICE
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
{
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
min
(
x
,
min
(
ys
...));
return
min
(
x
,
min
(
ys
...));
}
}
}
}
// namespace ck_tile
src/tl_templates/dcu_hip/debug.h
View file @
8f4628e0
...
@@ -189,4 +189,3 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
...
@@ -189,4 +189,3 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
(
int
)
threadIdx
.
x
,
(
int
)
threadIdx
.
y
,
(
int
)
threadIdx
.
z
,
safe_buf_name
,
(
int
)
threadIdx
.
x
,
(
int
)
threadIdx
.
y
,
(
int
)
threadIdx
.
z
,
safe_buf_name
,
index
,
var
);
index
,
var
);
}
}
src/tl_templates/dcu_hip/gemm.h
View file @
8f4628e0
...
@@ -69,7 +69,7 @@ template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
...
@@ -69,7 +69,7 @@ template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
typename
B_type
,
typename
C_type
,
typename
AccDataType
=
float
>
typename
B_type
,
typename
C_type
,
typename
AccDataType
=
float
>
class
GemmTensorOp
{
class
GemmTensorOp
{
public:
public:
//static_assert(!clear_accum, "clear_accum=true is not supported yet");
//
static_assert(!clear_accum, "clear_accum=true is not supported yet");
static
constexpr
int
micro_size_x
=
16
;
static
constexpr
int
micro_size_x
=
16
;
static
constexpr
int
micro_size_y
=
16
;
static
constexpr
int
micro_size_y
=
16
;
...
@@ -156,8 +156,8 @@ public:
...
@@ -156,8 +156,8 @@ public:
C_type
*
C_local
)
{
C_type
*
C_local
)
{
auto
tid
=
threadIdx
.
x
;
auto
tid
=
threadIdx
.
x
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_
n
=
warp_id
/
block_
row
_warps
;
auto
warp_
m
=
warp_id
/
block_
col
_warps
;
auto
warp_
m
=
warp_id
%
block_
row
_warps
;
auto
warp_
n
=
warp_id
%
block_
col
_warps
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
...
@@ -165,8 +165,8 @@ public:
...
@@ -165,8 +165,8 @@ public:
auto
tx
=
lane_id
;
auto
tx
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
...
@@ -194,7 +194,6 @@ public:
...
@@ -194,7 +194,6 @@ public:
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
r
+
row
,
l
+
col
)];
r
+
row
,
l
+
col
)];
}
}
}
}
}
}
...
@@ -237,8 +236,8 @@ public:
...
@@ -237,8 +236,8 @@ public:
C_type
*
C_local
)
{
C_type
*
C_local
)
{
auto
tid
=
threadIdx
.
x
;
auto
tid
=
threadIdx
.
x
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_
n
=
warp_id
/
block_
row
_warps
;
auto
warp_
m
=
warp_id
/
block_
col
_warps
;
auto
warp_
m
=
warp_id
%
block_
row
_warps
;
auto
warp_
n
=
warp_id
%
block_
col
_warps
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
...
@@ -246,7 +245,8 @@ public:
...
@@ -246,7 +245,8 @@ public:
auto
tx
=
lane_id
;
auto
tx
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
...
@@ -321,4 +321,3 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
...
@@ -321,4 +321,3 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
}
}
// namespace tl
}
// namespace tl
src/tl_templates/dcu_hip/hip_fp8.h
View file @
8f4628e0
...
@@ -72,4 +72,3 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
...
@@ -72,4 +72,3 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
res
.
y
=
*
reinterpret_cast
<
fp8_e4_4_t
*>
(
&
b
);
res
.
y
=
*
reinterpret_cast
<
fp8_e4_4_t
*>
(
&
b
);
return
res
;
return
res
;
}
}
src/tl_templates/dcu_hip/reduce.h
View file @
8f4628e0
...
@@ -22,14 +22,11 @@ struct MinOp {
...
@@ -22,14 +22,11 @@ struct MinOp {
}
}
};
};
// Detect half types
// Detect half types
template
<
typename
T
>
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
struct
is_half_type
:
std
::
false_type
{};
template
<
>
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
template
<
>
struct
is_half_type
<
_Float16
>
:
std
::
true_type
{};
struct
is_half_type
<
_Float16
>
:
std
::
true_type
{};
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
...
@@ -56,7 +53,10 @@ struct AllReduce {
...
@@ -56,7 +53,10 @@ struct AllReduce {
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
x_raw
=
__half_as_ushort
(
x
);
x_raw
=
__half_as_ushort
(
x
);
}
else
{
// _Float16
}
else
{
// _Float16
union
{
_Float16
f
;
unsigned
short
s
;
}
u
;
union
{
_Float16
f
;
unsigned
short
s
;
}
u
;
u
.
f
=
x
;
u
.
f
=
x
;
x_raw
=
u
.
s
;
x_raw
=
u
.
s
;
}
}
...
@@ -67,7 +67,10 @@ struct AllReduce {
...
@@ -67,7 +67,10 @@ struct AllReduce {
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
shuffled_x
=
__ushort_as_half
(
shuffled_raw
);
shuffled_x
=
__ushort_as_half
(
shuffled_raw
);
}
else
{
// _Float16
}
else
{
// _Float16
union
{
unsigned
short
s
;
_Float16
f
;
}
u
;
union
{
unsigned
short
s
;
_Float16
f
;
}
u
;
u
.
s
=
shuffled_raw
;
u
.
s
=
shuffled_raw
;
shuffled_x
=
u
.
f
;
shuffled_x
=
u
.
f
;
}
}
...
@@ -116,7 +119,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
...
@@ -116,7 +119,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
#pragma unroll
#pragma unroll
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
T
n
=
(
T
)
__shfl_down_sync
(
MASK
,
val
,
off
);
T
n
=
(
T
)
__shfl_down_sync
(
MASK
,
val
,
off
);
if
(
lane
<
SEG
-
off
)
if
(
lane
<
SEG
-
off
)
...
@@ -142,7 +145,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
...
@@ -142,7 +145,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
#pragma unroll
#pragma unroll
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
T
n
=
(
T
)
__shfl_up_sync
(
MASK
,
val
,
off
);
T
n
=
(
T
)
__shfl_up_sync
(
MASK
,
val
,
off
);
if
(
lane
>=
off
)
if
(
lane
>=
off
)
...
@@ -164,4 +167,3 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
...
@@ -164,4 +167,3 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
}
}
};
};
}
// namespace tl
}
// namespace tl
src/tl_templates/dcu_hip/threadblock_swizzle.h
View file @
8f4628e0
...
@@ -43,4 +43,3 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
...
@@ -43,4 +43,3 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
}
}
}
// namespace tl
}
// namespace tl
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
View file @
8f4628e0
...
@@ -12,6 +12,7 @@ from tilelang.transform import simplify_prim_func
...
@@ -12,6 +12,7 @@ from tilelang.transform import simplify_prim_func
tilelang
.
testing
.
set_random_seed
(
0
)
tilelang
.
testing
.
set_random_seed
(
0
)
tilelang
.
disable_cache
()
tilelang
.
disable_cache
()
def
make_swizzle_layout
(
shared_buf
):
def
make_swizzle_layout
(
shared_buf
):
dtype
=
shared_buf
.
dtype
dtype
=
shared_buf
.
dtype
shape
=
shared_buf
.
shape
shape
=
shared_buf
.
shape
...
@@ -63,7 +64,7 @@ def tl_matmul(
...
@@ -63,7 +64,7 @@ def tl_matmul(
chunk
=
32
*
k_pack
chunk
=
32
*
k_pack
shared_scope
=
"shared"
shared_scope
=
"shared"
cache_write_shared
=
False
#
cache_write_shared = False
block_M
=
block_row_warps
*
warp_row_tiles
block_M
=
block_row_warps
*
warp_row_tiles
block_N
=
block_col_warps
*
warp_col_tiles
block_N
=
block_col_warps
*
warp_col_tiles
...
...
tilelang/intrinsics/mmac_macro_generator.py
View file @
8f4628e0
from
__future__
import
annotations
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
import
tilelang.language
as
T
import
tilelang.language
as
T
from
typing
import
Tuple
from
tvm
import
DataType
from
tvm
import
DataType
from
tvm.tir
import
PrimExpr
from
tvm.tir
import
PrimExpr
from
tvm.runtime
import
convert
from
tvm.runtime
import
convert
from
typing
import
Optional
from
.utils
import
(
from
.utils
import
(
mfma_store_index_map
,)
mfma_store_index_map
,)
lift
=
convert
lift
=
convert
class
MatrixCoreIntrinEmitter
(
object
)
:
class
MatrixCoreIntrinEmitter
:
"""
"""
To eliminate Python syntax within TIR Macro.
To eliminate Python syntax within TIR Macro.
"""
"""
...
@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
chunk
:
int
=
16
,
chunk
:
int
=
16
,
reduce_k
:
int
=
1
,
reduce_k
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
k_pack
:
Optional
[
int
]
=
None
,
k_pack
:
int
|
None
=
None
,
is_m_first
:
Optional
[
bool
]
=
False
,
is_m_first
:
bool
|
None
=
False
,
b_preshuffle
:
Optional
[
bool
]
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
):
):
self
.
a_dtype
=
a_dtype
self
.
a_dtype
=
a_dtype
self
.
b_dtype
=
b_dtype
self
.
b_dtype
=
b_dtype
...
@@ -119,7 +118,7 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -119,7 +118,7 @@ class MatrixCoreIntrinEmitter(object):
"float16"
:
"f16"
,
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"int8"
:
"i8"
,
"bfloat16"
:
"bf16"
"bfloat16"
:
"bf16"
}[
in_dtype
]
}[
in_dtype
]
self
.
mmac_suffix
=
f
"
{
out_dtype_abbrv
}
_
{
M_DIM
}
x
{
N_DIM
}
x
{
k_dim
}{
in_dtype_abbrv
}
"
self
.
mmac_suffix
=
f
"
{
out_dtype_abbrv
}
_
{
M_DIM
}
x
{
N_DIM
}
x
{
k_dim
}{
in_dtype_abbrv
}
"
...
@@ -129,15 +128,15 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -129,15 +128,15 @@ class MatrixCoreIntrinEmitter(object):
self
.
micro_size_y
=
n_dim
self
.
micro_size_y
=
n_dim
self
.
micro_size_k
=
k_dim
self
.
micro_size_k
=
k_dim
def
_initialize_k_pack
(
self
,
k_pack
:
Optional
[
int
]
=
None
):
def
_initialize_k_pack
(
self
,
k_pack
:
int
|
None
=
None
):
if
k_pack
is
not
None
:
if
k_pack
is
not
None
:
self
.
k_pack
=
k_pack
self
.
k_pack
=
k_pack
def
_initialize_is_m_first
(
self
,
is_m_first
:
Optional
[
bool
]
=
False
):
def
_initialize_is_m_first
(
self
,
is_m_first
:
bool
|
None
=
False
):
if
is_m_first
is
not
None
:
if
is_m_first
is
not
None
:
self
.
is_m_first
=
is_m_first
self
.
is_m_first
=
is_m_first
def
_initialize_b_preshuffle
(
self
,
b_preshuffle
:
Optional
[
bool
]
=
False
):
def
_initialize_b_preshuffle
(
self
,
b_preshuffle
:
bool
|
None
=
False
):
if
b_preshuffle
is
not
None
:
if
b_preshuffle
is
not
None
:
self
.
b_preshuffle
=
b_preshuffle
self
.
b_preshuffle
=
b_preshuffle
...
@@ -197,7 +196,7 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -197,7 +196,7 @@ class MatrixCoreIntrinEmitter(object):
def
extract_thread_binding
(
self
,
def
extract_thread_binding
(
self
,
thread_id
,
thread_id
,
is_m_first
=
None
)
->
T
uple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
is_m_first
=
None
)
->
t
uple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
'''
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
@@ -290,7 +289,9 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -290,7 +289,9 @@ class MatrixCoreIntrinEmitter(object):
if
is_transposed
:
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
l
,
r
=
(
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
...
@@ -301,7 +302,9 @@ class MatrixCoreIntrinEmitter(object):
...
@@ -301,7 +302,9 @@ class MatrixCoreIntrinEmitter(object):
else
:
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
...
@@ -412,10 +415,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...
@@ -412,10 +415,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
chunk
:
int
=
16
,
chunk
:
int
=
16
,
reduce_k
:
int
=
1
,
reduce_k
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
k_pack
:
Optional
[
int
]
=
None
,
k_pack
:
int
|
None
=
None
,
is_m_first
:
Optional
[
bool
]
=
False
,
is_m_first
:
bool
|
None
=
False
,
a_preshuffle
:
Optional
[
bool
]
=
False
,
a_preshuffle
:
bool
|
None
=
False
,
b_preshuffle
:
Optional
[
bool
]
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
):
):
self
.
a_dtype
=
a_dtype
self
.
a_dtype
=
a_dtype
...
@@ -579,7 +582,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...
@@ -579,7 +582,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
if
is_transposed
:
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
l
,
r
=
(
l
,
r
=
(
warp_n
*
warp_cols
+
j
,
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
...
@@ -589,7 +594,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...
@@ -589,7 +594,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
else
:
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
l
,
r
=
(
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_cols
+
j
,
warp_n
*
warp_cols
+
j
,
...
@@ -600,4 +607,3 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...
@@ -600,4 +607,3 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
return
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
return
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment