Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
549416f7
Commit
549416f7
authored
Jan 11, 2025
by
LeiWang1999
Browse files
Merge branch 'main' of
https://github.com/microsoft/TileLang
into main
parents
4d63633a
7fad4e88
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
44 additions
and
78 deletions
+44
-78
tilelang/language/customize.py
tilelang/language/customize.py
+2
-6
tilelang/language/gemm.py
tilelang/language/gemm.py
+1
-0
tilelang/language/kernel.py
tilelang/language/kernel.py
+1
-0
tilelang/language/pipeline.py
tilelang/language/pipeline.py
+1
-3
tilelang/language/reduce.py
tilelang/language/reduce.py
+4
-9
tilelang/layout/swizzle.py
tilelang/layout/swizzle.py
+1
-0
tilelang/primitives/gemm/__init__.py
tilelang/primitives/gemm/__init__.py
+6
-9
tilelang/primitives/gemm/base.py
tilelang/primitives/gemm/base.py
+12
-26
tilelang/primitives/gemm/gemm_mma.py
tilelang/primitives/gemm/gemm_mma.py
+14
-25
tilelang/primitives/utils.py
tilelang/primitives/utils.py
+2
-0
No files found.
tilelang/language/customize.py
View file @
549416f7
...
@@ -10,12 +10,8 @@ def atomic_add(dst, value):
...
@@ -10,12 +10,8 @@ def atomic_add(dst, value):
def
atomic_addx2
(
dst
,
value
):
def
atomic_addx2
(
dst
,
value
):
return
T
.
call_extern
(
return
T
.
call_extern
(
"handle"
,
"atomicAddx2"
,
T
.
address_of
(
dst
),
T
.
address_of
(
value
))
"handle"
,
"atomicAddx2"
,
T
.
address_of
(
dst
),
T
.
address_of
(
value
)
)
def
dp4a
(
A
,
B
,
C
):
def
dp4a
(
A
,
B
,
C
):
return
T
.
call_extern
(
return
T
.
call_extern
(
"handle"
,
"DP4A"
,
T
.
address_of
(
A
),
T
.
address_of
(
B
),
T
.
address_of
(
C
))
"handle"
,
"DP4A"
,
T
.
address_of
(
A
),
T
.
address_of
(
B
),
T
.
address_of
(
C
)
)
tilelang/language/gemm.py
View file @
549416f7
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
tvm
import
tir
from
tvm
import
tir
class
GemmWarpPolicy
:
class
GemmWarpPolicy
:
Square
=
0
Square
=
0
FullRow
=
1
FullRow
=
1
...
...
tilelang/language/kernel.py
View file @
549416f7
...
@@ -145,6 +145,7 @@ class KernelLaunchFrame(TIRFrame):
...
@@ -145,6 +145,7 @@ class KernelLaunchFrame(TIRFrame):
"""
"""
return
self
.
get_num_threads
()
return
self
.
get_num_threads
()
def
Kernel
(
def
Kernel
(
*
blocks
:
List
[
tir
.
PrimExpr
],
*
blocks
:
List
[
tir
.
PrimExpr
],
threads
:
Union
[
int
,
List
[
int
],
Tuple
]
=
128
,
threads
:
Union
[
int
,
List
[
int
],
Tuple
]
=
128
,
...
...
tilelang/language/pipeline.py
View file @
549416f7
...
@@ -45,6 +45,4 @@ def Pipelined(
...
@@ -45,6 +45,4 @@ def Pipelined(
if
group
is
None
:
if
group
is
None
:
group
=
[]
group
=
[]
# type: ignore[attr-defined] # pylint: disable=no-member
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Pipelined
(
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
tilelang/language/reduce.py
View file @
549416f7
...
@@ -4,9 +4,8 @@
...
@@ -4,9 +4,8 @@
from
tvm
import
tir
from
tvm
import
tir
def
reduce
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
reduce_type
:
str
,
dim
:
int
,
clear
:
bool
def
reduce
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
reduce_type
:
str
,
dim
:
int
,
clear
:
bool
):
):
buffer
=
buffer
.
access_ptr
(
"r"
)
buffer
=
buffer
.
access_ptr
(
"r"
)
out
=
out
.
access_ptr
(
"w"
)
out
=
out
.
access_ptr
(
"w"
)
return
tir
.
call_intrin
(
return
tir
.
call_intrin
(
...
@@ -20,9 +19,7 @@ def reduce(
...
@@ -20,9 +19,7 @@ def reduce(
)
)
def
reduce_max
(
def
reduce_max
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
,
clear
:
bool
=
True
):
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
,
clear
:
bool
=
True
):
"""Perform reduce max on input buffer, store the result to output buffer
"""Perform reduce max on input buffer, store the result to output buffer
Parameters
Parameters
...
@@ -42,9 +39,7 @@ def reduce_max(
...
@@ -42,9 +39,7 @@ def reduce_max(
return
reduce
(
buffer
,
out
,
"max"
,
dim
,
clear
)
return
reduce
(
buffer
,
out
,
"max"
,
dim
,
clear
)
def
reduce_min
(
def
reduce_min
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
,
clear
:
bool
=
True
):
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
,
clear
:
bool
=
True
):
return
reduce
(
buffer
,
out
,
"min"
,
dim
,
clear
)
return
reduce
(
buffer
,
out
,
"min"
,
dim
,
clear
)
...
...
tilelang/layout/swizzle.py
View file @
549416f7
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
import
tvm
import
tvm
from
tilelang
import
_ffi_api
from
tilelang
import
_ffi_api
def
make_swizzled_layout
(
buffer
:
tvm
.
tir
.
Buffer
):
def
make_swizzled_layout
(
buffer
:
tvm
.
tir
.
Buffer
):
assert
len
(
buffer
.
shape
)
==
2
assert
len
(
buffer
.
shape
)
==
2
return
_ffi_api
.
make_swizzled_layout
(
return
_ffi_api
.
make_swizzled_layout
(
...
...
tilelang/primitives/gemm/__init__.py
View file @
549416f7
...
@@ -6,8 +6,8 @@ from tvm import tir
...
@@ -6,8 +6,8 @@ from tvm import tir
from
tilelang.primitives.utils
import
is_local
,
is_fragment
,
is_shared
from
tilelang.primitives.utils
import
is_local
,
is_fragment
,
is_shared
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.gemm_mma
import
(
from
tilelang.primitives.gemm.gemm_mma
import
(
GemmPrimitiveMMA
,
GemmPrimitiveMMA
,
)
)
def
gemm
(
def
gemm
(
A
:
tir
.
Buffer
,
A
:
tir
.
Buffer
,
...
@@ -24,16 +24,13 @@ def gemm(
...
@@ -24,16 +24,13 @@ def gemm(
k_pack
:
int
=
1
,
k_pack
:
int
=
1
,
):
):
assert
is_local
(
A
)
or
is_fragment
(
A
)
or
is_shared
(
A
),
(
assert
is_local
(
A
)
or
is_fragment
(
A
)
or
is_shared
(
A
),
(
f
"Expected A to be a local, fragment, or shared buffer, but got
{
A
.
scope
()
}
"
f
"Expected A to be a local, fragment, or shared buffer, but got
{
A
.
scope
()
}
"
)
)
assert
is_local
(
B
)
or
is_fragment
(
B
)
or
is_shared
(
B
),
(
assert
is_local
(
B
)
or
is_fragment
(
B
)
or
is_shared
(
B
),
(
f
"Expected B to be a local, fragment, or shared buffer, but got
{
B
.
scope
()
}
"
f
"Expected B to be a local, fragment, or shared buffer, but got
{
B
.
scope
()
}
"
)
)
assert
is_local
(
C
)
or
is_fragment
(
C
),
(
assert
is_local
(
C
)
or
is_fragment
(
C
),
(
f
"Expected C to be a local, fragment, but got
{
C
.
scope
()
}
"
f
"Expected C to be a local, fragment, but got
{
C
.
scope
()
}
"
)
)
# TODO(lei): Now we only support Nvidia GPUs
# TODO(lei): Now we only support Nvidia GPUs
# Must enhance the design to implement runtime lowering
# Must enhance the design to implement runtime lowering
# for different targets (hip mfma for example)
# for different targets (hip mfma for example)
return
GemmPrimitiveMMA
(
return
GemmPrimitiveMMA
(
A
=
A
,
A
=
A
,
...
...
tilelang/primitives/gemm/base.py
View file @
549416f7
...
@@ -7,6 +7,7 @@ from dataclasses import dataclass
...
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from
typing
import
Optional
from
typing
import
Optional
from
tvm
import
tir
from
tvm
import
tir
class
GemmWarpPolicy
(
IntEnum
):
class
GemmWarpPolicy
(
IntEnum
):
"""
"""
Enumeration for GEMM Warp Partitioning Policies.
Enumeration for GEMM Warp Partitioning Policies.
...
@@ -89,16 +90,12 @@ class GemmWarpPolicy(IntEnum):
...
@@ -89,16 +90,12 @@ class GemmWarpPolicy(IntEnum):
if
self
.
is_full_row
():
if
self
.
is_full_row
():
# FullRow policy: Allocate all warps to rows.
# FullRow policy: Allocate all warps to rows.
m_warp
=
num_warps
m_warp
=
num_warps
assert
(
assert
(
M
%
num_warps
==
0
),
"M must be divisible by num_warps for FullRow policy"
M
%
num_warps
==
0
),
"M must be divisible by num_warps for FullRow policy"
elif
self
.
is_full_col
():
elif
self
.
is_full_col
():
# FullCol policy: Allocate all warps to columns.
# FullCol policy: Allocate all warps to columns.
n_warp
=
num_warps
n_warp
=
num_warps
assert
(
assert
(
N
%
num_warps
==
0
),
"N must be divisible by num_warps for FullCol policy"
N
%
num_warps
==
0
),
"N must be divisible by num_warps for FullCol policy"
elif
self
.
is_square
():
elif
self
.
is_square
():
# Square policy: Try to balance warps across rows and columns.
# Square policy: Try to balance warps across rows and columns.
...
@@ -136,7 +133,7 @@ class GemmBaseParams:
...
@@ -136,7 +133,7 @@ class GemmBaseParams:
A
:
tir
.
Buffer
A
:
tir
.
Buffer
B
:
tir
.
Buffer
B
:
tir
.
Buffer
C
:
tir
.
Buffer
C
:
tir
.
Buffer
transpose_A
:
bool
=
False
transpose_A
:
bool
=
False
transpose_B
:
bool
=
False
transpose_B
:
bool
=
False
block_row_warps
:
Optional
[
int
]
=
None
block_row_warps
:
Optional
[
int
]
=
None
...
@@ -148,7 +145,7 @@ class GemmBaseParams:
...
@@ -148,7 +145,7 @@ class GemmBaseParams:
k_pack
:
int
=
1
k_pack
:
int
=
1
def
get_warp_size
(
self
)
->
int
:
def
get_warp_size
(
self
)
->
int
:
# must rewrite to 64 if the target
# must rewrite to 64 if the target
# is cdna mfma
# is cdna mfma
return
32
return
32
...
@@ -168,7 +165,6 @@ class GemmBaseParams:
...
@@ -168,7 +165,6 @@ class GemmBaseParams:
"k_pack"
:
self
.
k_pack
,
"k_pack"
:
self
.
k_pack
,
}
}
def
infer_block_partition
(
self
,
threads
:
Optional
[
int
])
->
None
:
def
infer_block_partition
(
self
,
threads
:
Optional
[
int
])
->
None
:
"""
"""
Infer and set block partition parameters (e.g., block_row_warps,
Infer and set block partition parameters (e.g., block_row_warps,
...
@@ -210,19 +206,13 @@ class GemmBaseParams:
...
@@ -210,19 +206,13 @@ class GemmBaseParams:
# Determine whether block partition parameters need to be inferred
# Determine whether block partition parameters need to be inferred
require_infer
=
(
require_infer
=
(
block_row_warps
is
None
block_row_warps
is
None
or
block_col_warps
is
None
or
warp_row_tiles
is
None
or
or
block_col_warps
is
None
warp_col_tiles
is
None
or
chunk
is
None
)
or
warp_row_tiles
is
None
or
warp_col_tiles
is
None
or
chunk
is
None
)
A_shape
,
B_shape
=
A
.
shape
,
B
.
shape
A_shape
,
B_shape
=
A
.
shape
,
B
.
shape
if
require_infer
:
if
require_infer
:
assert
(
assert
(
threads
is
not
None
),
"threads must be provided for auto inference"
threads
is
not
None
),
"threads must be provided for auto inference"
# Auto-inference only supports 2D matrix multiplication
# Auto-inference only supports 2D matrix multiplication
assert
(
assert
(
len
(
A_shape
)
==
2
and
len
(
B_shape
)
==
2
len
(
A_shape
)
==
2
and
len
(
B_shape
)
==
2
...
@@ -241,28 +231,24 @@ class GemmBaseParams:
...
@@ -241,28 +231,24 @@ class GemmBaseParams:
# Infer block partition using a user-specified policy
# Infer block partition using a user-specified policy
block_row_warps
,
block_col_warps
=
policy
.
compute_warp_partition
(
block_row_warps
,
block_col_warps
=
policy
.
compute_warp_partition
(
block_M
,
block_N
,
num_warps
block_M
,
block_N
,
num_warps
)
)
warp_row_tiles
=
block_M
//
block_row_warps
warp_row_tiles
=
block_M
//
block_row_warps
warp_col_tiles
=
block_N
//
block_col_warps
warp_col_tiles
=
block_N
//
block_col_warps
chunk
=
int
(
AK
)
chunk
=
int
(
AK
)
# rewrite the values
# rewrite the values
self
.
block_row_warps
=
block_row_warps
self
.
block_row_warps
=
block_row_warps
self
.
block_col_warps
=
block_col_warps
self
.
block_col_warps
=
block_col_warps
self
.
warp_row_tiles
=
warp_row_tiles
self
.
warp_row_tiles
=
warp_row_tiles
self
.
warp_col_tiles
=
warp_col_tiles
self
.
warp_col_tiles
=
warp_col_tiles
self
.
chunk
=
chunk
self
.
chunk
=
chunk
@
property
@
property
def
class_attributes
(
self
):
def
class_attributes
(
self
):
return
self
.
params_as_dict
()
return
self
.
params_as_dict
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
cls_name
=
self
.
__class__
.
__name__
cls_name
=
self
.
__class__
.
__name__
fields
=
self
.
class_attributes
fields
=
self
.
class_attributes
field_str
=
", "
.
join
(
field_str
=
", "
.
join
(
f
"
{
key
}
=
{
value
!
r
}
"
for
key
,
value
in
fields
.
items
())
f
"
{
key
}
=
{
value
!
r
}
"
for
key
,
value
in
fields
.
items
()
)
return
f
"
{
cls_name
}
(
{
field_str
}
)"
return
f
"
{
cls_name
}
(
{
field_str
}
)"
tilelang/primitives/gemm/gemm_mma.py
View file @
549416f7
...
@@ -11,6 +11,7 @@ from tilelang.primitives.utils import is_fragment, array_reduce
...
@@ -11,6 +11,7 @@ from tilelang.primitives.utils import is_fragment, array_reduce
from
tilelang.primitives.gemm.base
import
GemmBaseParams
from
tilelang.primitives.gemm.base
import
GemmBaseParams
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
# TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
# TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
@
dataclass
@
dataclass
class
GemmPrimitiveMMA
(
GemmBaseParams
):
class
GemmPrimitiveMMA
(
GemmBaseParams
):
...
@@ -35,7 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -35,7 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
B
:
tir
.
Buffer
,
B
:
tir
.
Buffer
,
C
:
tir
.
Buffer
,
C
:
tir
.
Buffer
,
mma_emitter
:
TensorCoreIntrinEmitter
,
mma_emitter
:
TensorCoreIntrinEmitter
,
)
->
tir
.
PrimExpr
:
)
->
tir
.
PrimExpr
:
in_dtype
=
self
.
in_dtype
in_dtype
=
self
.
in_dtype
warp_rows
=
mma_emitter
.
warp_rows
warp_rows
=
mma_emitter
.
warp_rows
...
@@ -50,9 +51,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -50,9 +51,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
c_is_fragment
=
is_fragment
(
C
)
c_is_fragment
=
is_fragment
(
C
)
@
T
.
macro
@
T
.
macro
def
_gemm_rsr
(
def
_gemm_rsr
(
A_local
:
tir
.
Buffer
,
B_shared
:
tir
.
Buffer
,
C_local
:
tir
.
Buffer
)
->
None
:
A_local
:
tir
.
Buffer
,
B_shared
:
tir
.
Buffer
,
C_local
:
tir
.
Buffer
)
->
None
:
"""
"""
The inner macro that loads data from shared buffers A_shared and
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
B_shared into local fragments, then issues Tensor Core mma ops,
...
@@ -63,18 +62,14 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -63,18 +62,14 @@ class GemmPrimitiveMMA(GemmBaseParams):
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
if
a_is_fragment
:
if
a_is_fragment
:
# Annotate layout for A_local if it is a fragment.
# Annotate layout for A_local if it is a fragment.
T
.
annotate_layout
(
T
.
annotate_layout
({
{
A_local
:
mma_emitter
.
make_mma_load_layout
(
A_local
,
"A"
),
A_local
:
mma_emitter
.
make_mma_load_layout
(
A_local
,
"A"
),
})
}
)
if
c_is_fragment
:
if
c_is_fragment
:
# Annotate layout for C_local if it is a fragment.
# Annotate layout for C_local if it is a fragment.
T
.
annotate_layout
(
T
.
annotate_layout
({
{
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
})
}
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
...
@@ -101,7 +96,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -101,7 +96,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
B
:
tir
.
Buffer
,
B
:
tir
.
Buffer
,
C
:
tir
.
Buffer
,
C
:
tir
.
Buffer
,
mma_emitter
:
TensorCoreIntrinEmitter
,
mma_emitter
:
TensorCoreIntrinEmitter
,
)
->
tir
.
PrimExpr
:
)
->
tir
.
PrimExpr
:
raise
NotImplementedError
(
"GEMM_RSR is not implemented yet"
)
raise
NotImplementedError
(
"GEMM_RSR is not implemented yet"
)
def
gemm_ssr
(
def
gemm_ssr
(
...
@@ -147,9 +142,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -147,9 +142,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
c_is_fragment
=
is_fragment
(
C
)
c_is_fragment
=
is_fragment
(
C
)
@
T
.
macro
@
T
.
macro
def
_gemm_ssr
(
def
_gemm_ssr
(
A_shared
:
tir
.
Buffer
,
B_shared
:
tir
.
Buffer
,
C_local
:
tir
.
Buffer
)
->
None
:
A_shared
:
tir
.
Buffer
,
B_shared
:
tir
.
Buffer
,
C_local
:
tir
.
Buffer
)
->
None
:
"""
"""
The inner macro that loads data from shared buffers A_shared and
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
B_shared into local fragments, then issues Tensor Core mma ops,
...
@@ -162,13 +155,9 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -162,13 +155,9 @@ class GemmPrimitiveMMA(GemmBaseParams):
if
c_is_fragment
:
if
c_is_fragment
:
# Annotate layout for C_local if it is a fragment.
# Annotate layout for C_local if it is a fragment.
T
.
annotate_layout
(
T
.
annotate_layout
({
{
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
C_local
:
mma_emitter
.
make_mma_store_layout
(
})
C_local
),
}
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
...
...
tilelang/primitives/utils.py
View file @
549416f7
...
@@ -37,6 +37,7 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
...
@@ -37,6 +37,7 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
conditions
.
append
(
is_shared_dynamic
(
buffer
))
conditions
.
append
(
is_shared_dynamic
(
buffer
))
return
any
(
conditions
)
return
any
(
conditions
)
def
is_shared_dynamic
(
buffer
:
Buffer
)
->
bool
:
def
is_shared_dynamic
(
buffer
:
Buffer
)
->
bool
:
"""
"""
Check if the buffer is in the dynamic shared memory scope.
Check if the buffer is in the dynamic shared memory scope.
...
@@ -75,6 +76,7 @@ def is_fragment(buffer: Buffer) -> bool:
...
@@ -75,6 +76,7 @@ def is_fragment(buffer: Buffer) -> bool:
"""
"""
return
buffer
.
scope
().
startswith
(
"local.fragment"
)
return
buffer
.
scope
().
startswith
(
"local.fragment"
)
def
array_reduce
(
array
:
List
[
int
])
->
int
:
def
array_reduce
(
array
:
List
[
int
])
->
int
:
"""
"""
Reduce an array of integers to a single integer.
Reduce an array of integers to a single integer.
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment