Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8f4628e0
Commit
8f4628e0
authored
Nov 11, 2025
by
qisan
Browse files
[Bugfix] Using a new data layout and the performance of NN gemm exceeds rocblas
parent
9a640856
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
124 additions
and
158 deletions
+124
-158
examples/gemm/example_gemm_intrinsics_dcu.py
examples/gemm/example_gemm_intrinsics_dcu.py
+4
-5
src/layout/gemm_layouts.cc
src/layout/gemm_layouts.cc
+5
-4
src/layout/layout.h
src/layout/layout.h
+2
-2
src/op/gemm.cc
src/op/gemm.cc
+5
-6
src/target/intrin_rule_hip.cc
src/target/intrin_rule_hip.cc
+0
-1
src/target/utils.cc
src/target/utils.cc
+0
-1
src/tl_templates/dcu_hip/common.h
src/tl_templates/dcu_hip/common.h
+18
-24
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+0
-1
src/tl_templates/dcu_hip/core.hpp
src/tl_templates/dcu_hip/core.hpp
+28
-57
src/tl_templates/dcu_hip/debug.h
src/tl_templates/dcu_hip/debug.h
+0
-1
src/tl_templates/dcu_hip/gemm.h
src/tl_templates/dcu_hip/gemm.h
+15
-16
src/tl_templates/dcu_hip/hip_fp8.h
src/tl_templates/dcu_hip/hip_fp8.h
+0
-1
src/tl_templates/dcu_hip/reduce.h
src/tl_templates/dcu_hip/reduce.h
+13
-11
src/tl_templates/dcu_hip/threadblock_swizzle.h
src/tl_templates/dcu_hip/threadblock_swizzle.h
+0
-1
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
+3
-2
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+31
-25
No files found.
examples/gemm/example_gemm_intrinsics_dcu.py
View file @
8f4628e0
...
...
@@ -8,7 +8,8 @@ from tilelang.intrinsics.mmac_macro_generator import (
from
tilelang.transform
import
simplify_prim_func
from
tilelang
import
disable_cache
disable_cache
()
disable_cache
()
def
make_swizzle_layout
(
shared_buf
):
dtype
=
shared_buf
.
dtype
...
...
@@ -81,7 +82,7 @@ def tl_matmul(
threads
=
warp_size
*
(
block_row_warps
*
block_col_warps
)
local_size_a
=
(
micro_size_x
*
micro_size_k
)
//
warp_size
local_size_b
=
(
micro_size_y
*
micro_size_k
)
//
warp_size
local_size_c
=
(
micro_size_x
*
micro_size_y
)
//
warp_size
local_size_c
=
(
micro_size_x
*
micro_size_y
)
//
warp_size
warp_rows
=
warp_row_tiles
//
micro_size_x
warp_cols
=
warp_col_tiles
//
micro_size_y
...
...
@@ -152,7 +153,7 @@ def tl_matmul(
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
C_shared
[
j
//
micro_size_y
,
i
//
micro_size_x
,
i
//
micro_size_x
,
i
%
micro_size_x
,
j
%
micro_size_y
,
]
...
...
@@ -186,5 +187,3 @@ def main():
if
__name__
==
"__main__"
:
main
()
src/layout/gemm_layouts.cc
View file @
8f4628e0
...
...
@@ -157,8 +157,8 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
}
Fragment
makeGemmFragmentCDCU
(
const
int
block_m
,
const
int
block_n
,
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
)
{
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
)
{
if
(
element_size
==
64
)
LOG
(
FATAL
)
<<
"Not supported"
;
ICHECK
(
block_m
%
warp_m
==
0
);
...
...
@@ -169,7 +169,7 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
auto
warp_layout
=
base_layout
->
Repeat
({
warp_m
/
16
,
warp_n
/
16
},
false
,
false
);
auto
block_layout
=
warp_layout
->
Repeat
({
block_m
/
warp_m
,
block_n
/
warp_n
},
true
,
tru
e
);
warp_layout
->
Repeat
({
block_m
/
warp_m
,
block_n
/
warp_n
},
true
,
fals
e
);
return
block_layout
;
}
...
...
@@ -747,7 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
if
(
!
k_inner
&&
element_size
==
8
)
// int8 KxN
return
makeGemmABLayoutPadded
(
mat_stride
,
mat_continuous
,
element_size
);
else
if
(
mat_continuous
%
(
vector_size
*
8
)
==
0
)
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
// element_size);
return
makeFullBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
else
if
(
mat_continuous
%
(
vector_size
*
4
)
==
0
)
return
makeHalfBankSwizzleLayout
(
mat_stride
,
mat_continuous
,
element_size
);
...
...
src/layout/layout.h
View file @
8f4628e0
...
...
@@ -151,8 +151,8 @@ Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
);
Fragment
makeGemmFragmentCDCU
(
const
int
block_m
,
const
int
block_n
,
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
);
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
);
Fragment
makeGemmFragmentCHopper
(
const
int
block_m
,
const
int
block_n
,
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
);
...
...
src/op/gemm.cc
View file @
8f4628e0
...
...
@@ -4,8 +4,8 @@
*/
#include "gemm.h"
#include <fstream>
#include "builtin.h"
#include <fstream>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
...
...
@@ -828,15 +828,14 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK
(
C
.
scope
()
==
"local.fragment"
)
<<
"CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<<
C
.
scope
();
if
(
TargetIsDCU
(
T
.
target
))
{
if
(
TargetIsDCU
(
T
.
target
))
{
auto
fragment
=
makeGemmFragmentCDCU
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
auto
fragment
=
makeGemmFragmentCCDNA
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
auto
fragment
=
makeGemmFragmentCCDNA
(
M
,
N
,
M
/
warp_m
,
N
/
warp_n
,
C
->
dtype
.
bits
());
results
.
Set
(
C
,
fragment
->
BindThreadRange
(
thread_range
));
}
if
(
A
.
scope
()
==
"shared"
||
A
.
scope
()
==
"shared.dyn"
)
{
...
...
src/target/intrin_rule_hip.cc
View file @
8f4628e0
...
...
@@ -249,7 +249,6 @@ TVM_REGISTER_OP("tir.hip.__shfl")
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TVM_REGISTER_OP
(
"tir.hip.__shfl_sync"
)
.
set_num_inputs
(
4
)
.
add_argument
(
"mask"
,
"Expr"
,
"The thread mask."
)
...
...
src/target/utils.cc
View file @
8f4628e0
...
...
@@ -5,7 +5,6 @@
#include "utils.h"
namespace
tvm
{
namespace
tl
{
...
...
src/tl_templates/dcu_hip/common.h
View file @
8f4628e0
#pragma once
#include "core.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
...
...
@@ -106,41 +105,36 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return
(
v1
<<
16
)
|
v0
;
}
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
half_t
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
half_t
>
:
std
::
true_type
{};
template
<
typename
T
>
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
void
AtomicAdd
(
T1
*
address
,
T2
val
)
{
if
constexpr
(
is_half_v
<
T1
>
)
{
__half
*
addr
=
reinterpret_cast
<
__half
*>
(
address
);
__half
hval
=
__float2half
(
static_cast
<
float
>
(
val
));
atomicAdd
(
addr
,
hval
);
}
else
{
atomicAdd
(
address
,
static_cast
<
T1
>
(
val
));
}
TL_DEVICE
void
AtomicAdd
(
T1
*
address
,
T2
val
)
{
if
constexpr
(
is_half_v
<
T1
>
)
{
__half
*
addr
=
reinterpret_cast
<
__half
*>
(
address
);
__half
hval
=
__float2half
(
static_cast
<
float
>
(
val
));
atomicAdd
(
addr
,
hval
);
}
else
{
atomicAdd
(
address
,
static_cast
<
T1
>
(
val
));
}
}
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
void
AtomicAdd
(
T1
&
ref
,
T2
val
)
{
AtomicAdd
(
&
ref
,
val
);
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
void
AtomicAdd
(
T1
&
ref
,
T2
val
)
{
AtomicAdd
(
&
ref
,
val
);
}
template
<
typename
T1
,
typename
T2
>
TL_DEVICE
T1
AtomicAddRet
(
T1
&
ref
,
T2
val
)
{
return
atomicAdd
(
&
ref
,
static_cast
<
T1
>
(
val
));
}
template
<
typename
T
>
TL_DEVICE
void
AtomicAddx4
(
T
*
ref
,
const
T
val
[
4
])
{
atomicAdd
(
&
ref
[
0
],
val
[
0
]);
atomicAdd
(
&
ref
[
1
],
val
[
1
]);
atomicAdd
(
&
ref
[
2
],
val
[
2
]);
atomicAdd
(
&
ref
[
3
],
val
[
3
]);
template
<
typename
T
>
TL_DEVICE
void
AtomicAddx4
(
T
*
ref
,
const
T
val
[
4
])
{
atomicAdd
(
&
ref
[
0
],
val
[
0
]);
atomicAdd
(
&
ref
[
1
],
val
[
1
]);
atomicAdd
(
&
ref
[
2
],
val
[
2
]);
atomicAdd
(
&
ref
[
3
],
val
[
3
]);
}
\ No newline at end of file
src/tl_templates/dcu_hip/copy.h
View file @
8f4628e0
...
...
@@ -108,4 +108,3 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
}
}
// namespace tl
src/tl_templates/dcu_hip/core.hpp
View file @
8f4628e0
...
...
@@ -14,7 +14,7 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) ||
\
defined(__gfx9__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
...
...
@@ -25,82 +25,53 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
namespace
ck_tile
{
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
max
(
T
x
)
{
return
x
;
}
namespace
ck_tile
{
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
max
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
}
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
}
template
<
>
CK_TILE_DEVICE
float
max
(
float
x
,
float
y
)
{
return
__builtin_fmaxf
(
x
,
y
);
// can resultin v_max3_f32
template
<
>
CK_TILE_DEVICE
float
max
(
float
x
,
float
y
)
{
return
__builtin_fmaxf
(
x
,
y
);
// can resultin v_max3_f32
}
template
<
>
CK_TILE_DEVICE
double
max
(
double
x
,
double
y
)
{
return
__builtin_fmax
(
x
,
y
);
// maybe still v_max3_f32
template
<
>
CK_TILE_DEVICE
double
max
(
double
x
,
double
y
)
{
return
__builtin_fmax
(
x
,
y
);
// maybe still v_max3_f32
}
template
<
typename
X
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
max
(
x
,
max
(
ys
...));
CK_TILE_HOST_DEVICE
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
max
(
x
,
max
(
ys
...));
}
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
min
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
T
min
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
template
<
typename
T
>
CK_TILE_HOST
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
}
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
template
<
typename
T
>
CK_TILE_DEVICE
constexpr
T
min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
}
template
<
>
CK_TILE_DEVICE
float
min
(
float
x
,
float
y
)
{
return
__builtin_fminf
(
x
,
y
);
template
<
>
CK_TILE_DEVICE
float
min
(
float
x
,
float
y
)
{
return
__builtin_fminf
(
x
,
y
);
}
template
<
>
CK_TILE_DEVICE
double
min
(
double
x
,
double
y
)
{
return
__builtin_fmin
(
x
,
y
);
template
<
>
CK_TILE_DEVICE
double
min
(
double
x
,
double
y
)
{
return
__builtin_fmin
(
x
,
y
);
}
template
<
typename
X
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
min
(
x
,
min
(
ys
...));
}
CK_TILE_HOST_DEVICE
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
min
(
x
,
min
(
ys
...));
}
}
// namespace ck_tile
src/tl_templates/dcu_hip/debug.h
View file @
8f4628e0
...
...
@@ -189,4 +189,3 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
(
int
)
threadIdx
.
x
,
(
int
)
threadIdx
.
y
,
(
int
)
threadIdx
.
z
,
safe_buf_name
,
index
,
var
);
}
src/tl_templates/dcu_hip/gemm.h
View file @
8f4628e0
...
...
@@ -69,7 +69,7 @@ template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
typename
B_type
,
typename
C_type
,
typename
AccDataType
=
float
>
class
GemmTensorOp
{
public:
//static_assert(!clear_accum, "clear_accum=true is not supported yet");
//
static_assert(!clear_accum, "clear_accum=true is not supported yet");
static
constexpr
int
micro_size_x
=
16
;
static
constexpr
int
micro_size_y
=
16
;
...
...
@@ -156,8 +156,8 @@ public:
C_type
*
C_local
)
{
auto
tid
=
threadIdx
.
x
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_
n
=
warp_id
/
block_
row
_warps
;
auto
warp_
m
=
warp_id
%
block_
row
_warps
;
auto
warp_
m
=
warp_id
/
block_
col
_warps
;
auto
warp_
n
=
warp_id
%
block_
col
_warps
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
...
...
@@ -165,8 +165,8 @@ public:
auto
tx
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
...
...
@@ -186,15 +186,14 @@ public:
for
(
int
local_id
=
0
;
local_id
<
(
kPack
*
local_size_b
);
local_id
++
)
{
if
constexpr
(
TransposeB
)
{
auto
[
row
,
col
]
=
reverse_index_map
(
blane_id
,
local_id
);
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
l
+
row
,
r
+
col
)];
}
else
{
auto
[
row
,
col
]
=
reverse_index_map_transposed
(
blane_id
,
local_id
);
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
r
+
row
,
l
+
col
)];
}
}
}
...
...
@@ -205,12 +204,12 @@ public:
for
(
int
local_id
=
0
;
local_id
<
(
kPack
*
local_size_a
);
local_id
++
)
{
if
constexpr
(
TransposeA
)
{
auto
[
row
,
col
]
=
reverse_index_map_transposed
(
alane_id
,
local_id
);
A_local
[
j
*
kPack
*
local_size_a
+
local_id
]
=
A_local
[
j
*
kPack
*
local_size_a
+
local_id
]
=
A_shared
[
make_swizzle_layout
<
last_dim_a
,
sizeof
(
A_type
)
>
(
r
+
row
,
l
+
col
)];
}
else
{
auto
[
row
,
col
]
=
reverse_index_map
(
alane_id
,
local_id
);
A_local
[
j
*
kPack
*
local_size_a
+
local_id
]
=
A_local
[
j
*
kPack
*
local_size_a
+
local_id
]
=
A_shared
[
make_swizzle_layout
<
last_dim_a
,
sizeof
(
A_type
)
>
(
l
+
row
,
r
+
col
)];
}
...
...
@@ -237,8 +236,8 @@ public:
C_type
*
C_local
)
{
auto
tid
=
threadIdx
.
x
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_
n
=
warp_id
/
block_
row
_warps
;
auto
warp_
m
=
warp_id
%
block_
row
_warps
;
auto
warp_
m
=
warp_id
/
block_
col
_warps
;
auto
warp_
n
=
warp_id
%
block_
col
_warps
;
auto
warp_row_tiles
=
warp_rows
*
micro_size_x
;
auto
warp_col_tiles
=
warp_cols
*
micro_size_y
;
...
...
@@ -246,7 +245,8 @@ public:
auto
tx
=
lane_id
;
auto
alane_id
=
lane_id
;
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
auto
blane_id
=
((
lane_id
&
15
)
>>
2
)
+
((
lane_id
&
3
)
<<
2
)
+
((
lane_id
>>
4
)
<<
4
);
constexpr
auto
local_size_a
=
(
micro_size_x
*
micro_size_k
)
/
warp_size
;
constexpr
auto
local_size_b
=
(
micro_size_y
*
micro_size_k
)
/
warp_size
;
...
...
@@ -265,12 +265,12 @@ public:
for
(
int
local_id
=
0
;
local_id
<
(
kPack
*
local_size_b
);
local_id
++
)
{
if
constexpr
(
TransposeB
)
{
auto
[
row
,
col
]
=
reverse_index_map
(
blane_id
,
local_id
);
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
l
+
row
,
r
+
col
)];
}
else
{
auto
[
row
,
col
]
=
reverse_index_map_transposed
(
blane_id
,
local_id
);
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_local
[
i
*
kPack
*
local_size_b
+
local_id
]
=
B_shared
[
make_swizzle_layout
<
last_dim_b
,
sizeof
(
B_type
)
>
(
r
+
row
,
l
+
col
)];
}
...
...
@@ -321,4 +321,3 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
}
// namespace tl
src/tl_templates/dcu_hip/hip_fp8.h
View file @
8f4628e0
...
...
@@ -72,4 +72,3 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
res
.
y
=
*
reinterpret_cast
<
fp8_e4_4_t
*>
(
&
b
);
return
res
;
}
src/tl_templates/dcu_hip/reduce.h
View file @
8f4628e0
...
...
@@ -22,14 +22,11 @@ struct MinOp {
}
};
// Detect half types
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_half_type
:
std
::
false_type
{};
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
__half
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
_Float16
>
:
std
::
true_type
{};
template
<
>
struct
is_half_type
<
_Float16
>
:
std
::
true_type
{};
template
<
typename
T
>
inline
constexpr
bool
is_half_v
=
is_half_type
<
std
::
decay_t
<
T
>>::
value
;
...
...
@@ -56,7 +53,10 @@ struct AllReduce {
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
x_raw
=
__half_as_ushort
(
x
);
}
else
{
// _Float16
union
{
_Float16
f
;
unsigned
short
s
;
}
u
;
union
{
_Float16
f
;
unsigned
short
s
;
}
u
;
u
.
f
=
x
;
x_raw
=
u
.
s
;
}
...
...
@@ -67,7 +67,10 @@ struct AllReduce {
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
T
>
,
__half
>
)
{
shuffled_x
=
__ushort_as_half
(
shuffled_raw
);
}
else
{
// _Float16
union
{
unsigned
short
s
;
_Float16
f
;
}
u
;
union
{
unsigned
short
s
;
_Float16
f
;
}
u
;
u
.
s
=
shuffled_raw
;
shuffled_x
=
u
.
f
;
}
...
...
@@ -116,7 +119,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
#pragma unroll
#pragma unroll
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
T
n
=
(
T
)
__shfl_down_sync
(
MASK
,
val
,
off
);
if
(
lane
<
SEG
-
off
)
...
...
@@ -142,7 +145,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T
val
=
(
col
<
W
)
?
src
[
real_row
*
W
+
real_col
]
:
(
T
)
0
;
#pragma unroll
#pragma unroll
for
(
int
off
=
1
;
off
<
SEG
;
off
<<=
1
)
{
T
n
=
(
T
)
__shfl_up_sync
(
MASK
,
val
,
off
);
if
(
lane
>=
off
)
...
...
@@ -164,4 +167,3 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
}
};
}
// namespace tl
src/tl_templates/dcu_hip/threadblock_swizzle.h
View file @
8f4628e0
...
...
@@ -43,4 +43,3 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
}
}
// namespace tl
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
View file @
8f4628e0
...
...
@@ -12,6 +12,7 @@ from tilelang.transform import simplify_prim_func
tilelang
.
testing
.
set_random_seed
(
0
)
tilelang
.
disable_cache
()
def
make_swizzle_layout
(
shared_buf
):
dtype
=
shared_buf
.
dtype
shape
=
shared_buf
.
shape
...
...
@@ -63,7 +64,7 @@ def tl_matmul(
chunk
=
32
*
k_pack
shared_scope
=
"shared"
cache_write_shared
=
False
#
cache_write_shared = False
block_M
=
block_row_warps
*
warp_row_tiles
block_N
=
block_col_warps
*
warp_col_tiles
...
...
@@ -171,7 +172,7 @@ def tl_matmul(
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
C_shared
[
j
//
micro_size_y
,
i
//
micro_size_x
,
i
//
micro_size_x
,
i
%
micro_size_x
,
j
%
micro_size_y
,
]
...
...
tilelang/intrinsics/mmac_macro_generator.py
View file @
8f4628e0
from
__future__
import
annotations
from
tilelang
import
tvm
as
tvm
import
tilelang.language
as
T
from
typing
import
Tuple
from
tvm
import
DataType
from
tvm.tir
import
PrimExpr
from
tvm.runtime
import
convert
from
typing
import
Optional
from
.utils
import
(
mfma_store_index_map
,)
lift
=
convert
class
MatrixCoreIntrinEmitter
(
object
)
:
class
MatrixCoreIntrinEmitter
:
"""
To eliminate Python syntax within TIR Macro.
"""
...
...
@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
chunk
:
int
=
16
,
reduce_k
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
k_pack
:
Optional
[
int
]
=
None
,
is_m_first
:
Optional
[
bool
]
=
False
,
b_preshuffle
:
Optional
[
bool
]
=
False
,
k_pack
:
int
|
None
=
None
,
is_m_first
:
bool
|
None
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
):
self
.
a_dtype
=
a_dtype
self
.
b_dtype
=
b_dtype
...
...
@@ -119,7 +118,7 @@ class MatrixCoreIntrinEmitter(object):
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"bfloat16"
:
"bf16"
"bfloat16"
:
"bf16"
}[
in_dtype
]
self
.
mmac_suffix
=
f
"
{
out_dtype_abbrv
}
_
{
M_DIM
}
x
{
N_DIM
}
x
{
k_dim
}{
in_dtype_abbrv
}
"
...
...
@@ -129,15 +128,15 @@ class MatrixCoreIntrinEmitter(object):
self
.
micro_size_y
=
n_dim
self
.
micro_size_k
=
k_dim
def
_initialize_k_pack
(
self
,
k_pack
:
Optional
[
int
]
=
None
):
def
_initialize_k_pack
(
self
,
k_pack
:
int
|
None
=
None
):
if
k_pack
is
not
None
:
self
.
k_pack
=
k_pack
def
_initialize_is_m_first
(
self
,
is_m_first
:
Optional
[
bool
]
=
False
):
def
_initialize_is_m_first
(
self
,
is_m_first
:
bool
|
None
=
False
):
if
is_m_first
is
not
None
:
self
.
is_m_first
=
is_m_first
def
_initialize_b_preshuffle
(
self
,
b_preshuffle
:
Optional
[
bool
]
=
False
):
def
_initialize_b_preshuffle
(
self
,
b_preshuffle
:
bool
|
None
=
False
):
if
b_preshuffle
is
not
None
:
self
.
b_preshuffle
=
b_preshuffle
...
...
@@ -197,7 +196,7 @@ class MatrixCoreIntrinEmitter(object):
def
extract_thread_binding
(
self
,
thread_id
,
is_m_first
=
None
)
->
T
uple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
is_m_first
=
None
)
->
t
uple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -290,7 +289,9 @@ class MatrixCoreIntrinEmitter(object):
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
...
...
@@ -301,7 +302,9 @@ class MatrixCoreIntrinEmitter(object):
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
...
...
@@ -372,13 +375,13 @@ class MatrixCoreIntrinEmitter(object):
row
,
col
=
T
.
meta_var
(
mfma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
j
*
(
warp_rows
*
local_size_out
)
+
i
*
local_size_out
+
local_id
]
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
j
*
(
warp_rows
*
local_size_out
)
+
i
*
local_size_out
+
local_id
]
else
:
C_buf
[
warp_n
*
warp_cols
+
j
,
warp_m
*
warp_rows
+
i
,
row
,
col
]
=
C_local_buf
[
j
*
warp_rows
*
local_size_out
+
i
*
local_size_out
+
local_id
]
col
]
=
C_local_buf
[
j
*
warp_rows
*
local_size_out
+
i
*
local_size_out
+
local_id
]
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
...
@@ -412,10 +415,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
chunk
:
int
=
16
,
reduce_k
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
k_pack
:
Optional
[
int
]
=
None
,
is_m_first
:
Optional
[
bool
]
=
False
,
a_preshuffle
:
Optional
[
bool
]
=
False
,
b_preshuffle
:
Optional
[
bool
]
=
False
,
k_pack
:
int
|
None
=
None
,
is_m_first
:
bool
|
None
=
False
,
a_preshuffle
:
bool
|
None
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
):
self
.
a_dtype
=
a_dtype
...
...
@@ -579,7 +582,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
l
,
r
=
(
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
...
...
@@ -589,7 +594,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(((
tx
&
15
)
>>
2
)
+
((
tx
&
3
)
<<
2
)
+
((
tx
>>
4
)
<<
4
),
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_cols
+
j
,
...
...
@@ -600,4 +607,3 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
return
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment