OpenDAS / tilelang, commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed via GitHub on Dec 12, 2025
Parent: e84b24bc

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

426 files changed in the full commit; the 20 files shown below account for 479 additions and 652 deletions (+479, -652).
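Most of the line churn below is mechanical: yapf's one-argument-per-line signatures are collapsed or rewrapped with trailing commas, single-quoted strings become double-quoted, redundant parentheses are dropped, and backslash-continued or implicitly concatenated f-strings are merged. A minimal before/after sketch of those transformations (the function and values here are illustrative, not taken from the diff):

# Before (yapf): wrapped signature, single quotes, redundant parentheses.
def compile_kernel_yapf(code,
                        target='ptx',
                        verbose=False):
    flags = (['-O3', '-c'])
    return code, target, flags


# After (ruff format): collapsed signature, double quotes, parentheses dropped.
def compile_kernel_ruff(code, target="ptx", verbose=False):
    flags = ["-O3", "-c"]
    return code, target, flags


assert compile_kernel_yapf("x") == compile_kernel_ruff("x")  # formatting only; behavior is identical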
Changed files:

tilelang/contrib/hipcc.py                         +2    -7
tilelang/contrib/nvcc.py                         +12   -25
tilelang/contrib/nvrtc.py                        +10   -11
tilelang/contrib/rocm.py                          +5    -2
tilelang/engine/lower.py                         +17   -22
tilelang/engine/param.py                          +3    -0
tilelang/engine/phase.py                          +6   -12
tilelang/env.py                                  +33   -39
tilelang/intrinsics/mfma_layout.py                +9    -9
tilelang/intrinsics/mfma_macro_generator.py      +88  -107
tilelang/intrinsics/mma_layout.py                +11   -11
tilelang/intrinsics/mma_macro_generator.py       +50   -93
tilelang/intrinsics/mma_sm70_layout.py            +3    -5
tilelang/intrinsics/mma_sm70_macro_generator.py  +21   -54
tilelang/intrinsics/mma_sp_layout.py              +9   -18
tilelang/intrinsics/mma_sp_macro_generator.py    +41   -74
tilelang/intrinsics/tcgen05_macro_generator.py   +49   -44
tilelang/intrinsics/utils.py                      +1    -1
tilelang/intrinsics/wgmma_macro_generator.py     +93   -86
tilelang/ir.py                                   +16   -32
tilelang/contrib/hipcc.py

@@ -16,12 +16,7 @@ from tvm.base import py_str
 from tvm.contrib.rocm import get_rocm_arch, find_rocm_path


-def compile_hip(code,
-                target_format="hsaco",
-                arch=None,
-                options=None,
-                path_target=None,
-                verbose=False):
+def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False):
     """Compile HIP code with hipcc.

     Parameters
@@ -61,7 +56,7 @@ def compile_hip(code,
     file_target = path_target if path_target else temp_target
     cmd = ["hipcc"]
-    cmd += ["-O3", '-c']
+    cmd += ["-O3", "-c"]
     if isinstance(arch, str):
         cmd += [f"--offload-arch={arch}"]
     if target_format == "hsaco":
tilelang/contrib/nvcc.py

 # pylint: disable=invalid-name
 # modified from apache tvm python/tvm/contrib/nvcc.py
 """Utility to invoke nvcc compiler in the system"""
 from __future__ import annotations

 import os
@@ -18,12 +19,7 @@ from tvm.base import py_str
 from tvm.contrib import utils


-def compile_cuda(code,
-                 target_format="ptx",
-                 arch=None,
-                 options=None,
-                 path_target=None,
-                 verbose=False):
+def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False):
     """Compile cuda code with NVCC from env.

     Parameters
@@ -67,7 +63,7 @@ def compile_cuda(code,
     temp_target = temp.relpath(f"{file_name}.{target_format}")
     pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
-    kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None))
+    kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None)
     if kernels_output_dir is not None:
         if not os.path.isdir(kernels_output_dir):
             os.makedirs(kernels_output_dir)
@@ -114,10 +110,7 @@ def compile_cuda(code,
         print(py_str(out))
     if proc.returncode != 0:
-        msg = f"{code}\n" \
-              f"Compilation error:\n" \
-              f"{py_str(out)}\n" \
-              f"Command: {' '.join(cmd)}\n"
+        msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n"
         raise RuntimeError(msg)

     with open(file_target, "rb") as f:
@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
     # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
     if compile_flags:
         import shlex
+
         for flag in compile_flags:
             # Split each string like a shell would, preserving quoted args
             tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
     return options


-def get_ptx_from_source(code: str,
-                        compile_flags: list[str] | None = None,
-                        verbose: bool = False) -> str:
+def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
     """
     Compile CUDA C++ source to PTX using NVCC and return as text.
@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None:
     return None


-def get_sass_from_source(code: str,
-                         compile_flags: list[str] | None = None,
-                         verbose: bool = False) -> str:
+def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
     """
     Compile CUDA C++ source to CUBIN and disassemble to SASS.
@@ -246,9 +236,7 @@ def get_sass_from_source(code: str,
     cand_nvdisasm = _find_tool("nvdisasm")
     cand_cuobjdump = _find_tool("cuobjdump")
     if not cand_nvdisasm and not cand_cuobjdump:
         raise RuntimeError(
             "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
         )
     last_err: str | None = None
     try:
         # Attempt nvdisasm first
@@ -268,8 +256,7 @@ def get_sass_from_source(code: str,
                 return text
             last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
         # If we reach here, all attempts failed
-        raise RuntimeError(f"SASS disassembly failed. Tried tools: "
-                           f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
+        raise RuntimeError(
+            f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}"
+        )
     finally:
         with contextlib.suppress(Exception):
             os.remove(cubin_path)
@@ -438,8 +425,7 @@ def get_target_compute_version(target=None):
     if tvm.cuda(0).exist:
         return tvm.cuda(0).compute_version

-    raise ValueError("No CUDA architecture was specified or GPU detected."
-                     "Try specifying it by adding '-arch=sm_xx' to your target.")
+    raise ValueError(
+        "No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target."
+    )


 def parse_compute_version(compute_version) -> tuple[int, int]:
@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None):
             warnings.warn(
                 "Tensorcore will be disabled due to no CUDA architecture specified."
-                "Try specifying it by adding '-arch=sm_xx' to your target.",
-                stacklevel=2)
+                "Try specifying it by adding '-arch=sm_xx' to your target.",
+                stacklevel=2,
+            )
             return False
         compute_version = target.attrs["arch"]
     # Compute version will be in the form "sm_{major}{minor}"
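The default_compile_options hunk keeps the shlex-based tokenization of user compile flags; a standalone check of the exact expression from that hunk (the flag value is illustrative):

import shlex

# Split each string like a shell would, preserving quoted arguments --
# the same call the hunk above applies to user-supplied compile_flags.
flag = '-Xcompiler "-Wall -O2"'
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
assert tokens == ["-Xcompiler", "-Wall -O2"]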
tilelang/contrib/nvrtc.py

@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]:
     return (major, minor)


-def compile_cuda(code: str,
-                 target_format: Literal["ptx", "cubin"] = "ptx",
-                 arch: int | None = None,
-                 options: str | list[str] | None = None,
-                 verbose: bool = False) -> bytearray:
+def compile_cuda(
+    code: str,
+    target_format: Literal["ptx", "cubin"] = "ptx",
+    arch: int | None = None,
+    options: str | list[str] | None = None,
+    verbose: bool = False,
+) -> bytearray:
     """Compile cuda code with NVRTC.

     Parameters
@@ -43,8 +45,7 @@ def compile_cuda(code: str,
     if arch is None:
         # If None, then it will use `tvm.target.Target.current().arch`.
         # Target arch could be a str like "80", "90", "90a", etc.
-        major, minor = parse_compute_version(
-            get_target_compute_version(Target.current(allow_none=True)))
+        major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
         arch = major * 10 + minor
     prefix = "compute" if target_format == "ptx" else "sm"
     suffix = "a" if arch >= 90 else ""
@@ -77,8 +78,7 @@ def compile_cuda(code: str,
     compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0]

     if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-        msg = f"{code}\n" \
-              f"Compilation error:\n"
+        msg = f"{code}\nCompilation error:\n"
         if verbose:
             result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
             assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}"
@@ -105,7 +105,6 @@ def compile_cuda(code: str,
     assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}"

     # Destroy handler
-    assert nvrtc.nvrtcDestroyProgram(
-        program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
+    assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"

     return result_bytes
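The prefix/suffix assignments in the nvrtc hunk choose between virtual ("compute_XX") and real ("sm_XX") architectures, with the "a" variant for arch >= 90. A small sketch, assuming the pieces are joined as "{prefix}_{arch}{suffix}" (the joining itself is not shown in this hunk):

def arch_string(arch: int, target_format: str) -> str:
    # From the hunk above: PTX targets use a virtual "compute_XX" arch,
    # cubin targets a real "sm_XX"; arch >= 90 needs the "a" variant.
    prefix = "compute" if target_format == "ptx" else "sm"
    suffix = "a" if arch >= 90 else ""
    return f"{prefix}_{arch}{suffix}"  # assumed assembly, for illustration

assert arch_string(90, "cubin") == "sm_90a"
assert arch_string(80, "ptx") == "compute_80"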
tilelang/contrib/rocm.py

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Utility for ROCm backend"""
+# ruff: noqa
 import re
 import subprocess
@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"):
             gpu_arch = match.group(1)
         return gpu_arch
     except subprocess.CalledProcessError:
-        print(f"Unable to execute rocminfo command, \
-please ensure ROCm is installed and you have an AMD GPU on your system. \
-using default {gpu_arch}.")
+        print(
+            f"Unable to execute rocminfo command, \
+please ensure ROCm is installed and you have an AMD GPU on your system. \
+using default {gpu_arch}."
+        )
         return gpu_arch
tilelang/engine/lower.py

 """The compiler for TL programs."""
 from __future__ import annotations

 import os
@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target):
 def has_device_kernel_launch(attrs) -> bool:
     """Check if the attributes indicate a device kernel launch."""
-    return bool(attrs and "calling_conv" in attrs and
-                attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
+    return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)


 def is_device_call_c_device(func: tir.PrimFunc):
     attrs = func.attrs
     calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT)
-    is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC)
+    is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC
     # Check if it's a C target
     if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked:
@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
         if var in func.buffer_map:
             tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
         else:
-            if var.dtype == 'handle':
+            if var.dtype == "handle":
                 raise ValueError(
-                    f'Handle parameter {var} must be mapped to a buffer.\n'
-                    f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.')
+                    f"Handle parameter {var} must be mapped to a buffer.\n"
+                    f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer."
+                )
             tensor_types.append(KernelParam.from_var(var))
     return tensor_types


 def canon_target_host(target: str | Target, target_host: str | Target | None):
     if not target_host:
         target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
     device_mod = tilelang.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
-            device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
     elif target.kind.name == "hip":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(
-            device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
     elif target.kind.name == "c":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
     elif target.kind.name == "llvm":
@@ -222,12 +220,12 @@ def lower(
     enable_host_codegen=False,
     enable_device_compile=False,
 ) -> CompiledArtifact:
-    '''
+    """
     enable_host_codegen: whether to enable host codegen, default is False, as we have our
     own host codegen implementation in jit.

     enable_device_compile: whether to enable device codegen, default is False, as we have our
     own device codegen implementation in jit.
-    '''
+    """
     mod = func_or_mod
     params = None
@@ -259,14 +257,11 @@ def lower(
     host_mod = tir.transform.Filter(_is_host_call)(mod)
     device_mod = tir.transform.Filter(_is_device_call)(mod)

-    codegen_mod = device_codegen(
-        device_mod, target) if enable_device_compile else device_codegen_without_compile(
-            device_mod, target)
+    codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target)

     if enable_host_codegen:
         host_mod = host_codegen(host_mod, target_host)
         host_mod.import_module(codegen_mod)
-        return CompiledArtifact(
-            host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
+        return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)

     return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
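The has_device_kernel_launch reflow above preserves a simple attribute test; a self-contained sketch of that test, with a stand-in constant for CallingConv.DEVICE_KERNEL_LAUNCH (the real enum lives in TVM; the value 2 is an assumption here):

# Stand-in for tvm.tir.CallingConv.DEVICE_KERNEL_LAUNCH (assumed value).
DEVICE_KERNEL_LAUNCH = 2

def has_device_kernel_launch(attrs) -> bool:
    # Same logic as the hunk above: attrs must exist, carry a
    # "calling_conv" key, and that key must mark a device kernel launch.
    return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == DEVICE_KERNEL_LAUNCH)

assert has_device_kernel_launch({"calling_conv": DEVICE_KERNEL_LAUNCH})
assert not has_device_kernel_launch(None)
assert not has_device_kernel_launch({"calling_conv": 0})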
tilelang/engine/param.py

 """The profiler and convert to torch utils"""
 from __future__ import annotations

 from dataclasses import dataclass
@@ -14,6 +15,7 @@ class KernelParam:
     Represents parameters for a kernel operation, storing dtype and shape information.
     Used to describe tensor or scalar parameters in TVM/PyTorch interop.
     """
+
     dtype: torch.dtype  # PyTorch data type of the parameter
     shape: list[int | Var]  # List of dimensions, can be integers or TVM variables
@@ -109,6 +111,7 @@ class CompiledArtifact:
     Represents a compiled kernel artifact containing both host and device code.
     Stores all necessary components for kernel execution in the TVM runtime.
     """
+
     host_mod: tvm.IRModule  # Host-side TVM IR module for managing kernel execution
     device_mod: tvm.IRModule  # Device-side TVM IR module containing the actual kernel code
     params: list[KernelParam]  # List of parameters (tensors/scalars) used by the kernel
tilelang/engine/phase.py

@@ -6,8 +6,7 @@ from tilelang.transform import PassContext
 from tilelang.contrib.nvcc import have_tma, is_hopper


-def allow_warp_specialized(pass_ctx: PassContext | None = None,
-                           target: Target | None = None) -> bool:
+def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     # avoid circular import
     from tilelang.jit.adapter.utils import is_cuda_target
@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None,
     return not disable_warp_specialized


-def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
-                                   target: Target | None = None) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
     if not have_tma(target):
@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) ->
     return enable_global_thread_sync


-def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
-                                   target: Target | None = None) -> bool:
+def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    enable_aggressive_merge = bool(
-        pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE,
-                            False))
+    enable_aggressive_merge = bool(
+        pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)
+    )
     if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
         # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
         # when warp specialization is enabled, as different warp threads may access different
@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
         return ["txt", "png", "pdf", "svg"]
     if "," in formats_str:
-        formats_list = [f.strip() for f in formats_str.split(',')]
+        formats_list = [f.strip() for f in formats_str.split(",")]
     else:
         formats_list = [formats_str]
@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # MergeSharedMemoryAllocations must be applied after SplitHostDevice
     # because the merged allocation site is at the beginning of each device function
     enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
-    mod = tilelang.transform.MergeSharedMemoryAllocations(
-        enable_aggressive_merge=enable_aggressive_merge)(
-            mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
    mod = tilelang.transform.ThreadSync("shared")(mod)
    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
    # Inject PTX async copy must behind the thread sync pass
tilelang/env.py

@@ -10,36 +10,34 @@ from dataclasses import dataclass
 logger = logging.getLogger(__name__)

 # SETUP ENVIRONMENT VARIABLES
-CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
-                             ", which may lead to compilation bugs when utilize tilelang backend.")
-COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
-    "Composable Kernel is not installed or found in the expected path"
-    ", which may lead to compilation bugs when utilize tilelang backend.")
-TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path"
-                                 ", which may lead to compilation bugs when utilize tilelang backend.")
-TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
+CUTLASS_NOT_FOUND_MESSAGE = (
+    "CUTLASS is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
+    "Composable Kernel is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+TL_TEMPLATE_NOT_FOUND_MESSAGE = (
+    "TileLang is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"

 TL_ROOT = os.path.dirname(os.path.abspath(__file__))
 # Only expose the internal lib directory to sys.path to avoid shadowing
 # common top-level module names (e.g., utils, analysis) from user projects.
-TL_LIBS = [os.path.join(TL_ROOT, 'lib')]
+TL_LIBS = [os.path.join(TL_ROOT, "lib")]
 TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)]

 DEV = False
-THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty')
+THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty")
 if not os.path.exists(THIRD_PARTY_ROOT):
     DEV = True
     tl_dev_root = os.path.dirname(TL_ROOT)
-    dev_lib_root = os.path.join(tl_dev_root, 'build')
+    dev_lib_root = os.path.join(tl_dev_root, "build")
     # In dev builds, place artifacts under build/lib and point search path there
     # to avoid adding the entire build root to sys.path.
-    TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')]
-    THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty')
-    logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}')
-    assert TL_LIBS and all(
-        os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}'
+    TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")]
+    THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty")
+    logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}")
+    assert TL_LIBS and all(
+        os.path.exists(i) for i in TL_LIBS
+    ), f"tilelang lib root do not exists: {TL_LIBS}"
 for lib in TL_LIBS:
     if lib not in sys.path:
@@ -52,7 +50,7 @@ def _find_cuda_home() -> str:
     Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
     """
     # Guess #1
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
     if cuda_home is None:
         # Guess #2
         nvcc_path = shutil.which("nvcc")
@@ -70,15 +68,15 @@ def _find_cuda_home() -> str:
         else:
             # Guess #3
-            if sys.platform == 'win32':
-                cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
-                cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0]
+            if sys.platform == "win32":
+                cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
+                cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0]
             else:
                 # Linux/macOS
-                if os.path.exists('/usr/local/cuda'):
-                    cuda_home = '/usr/local/cuda'
-                elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'):
-                    cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64'
+                if os.path.exists("/usr/local/cuda"):
+                    cuda_home = "/usr/local/cuda"
+                elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"):
+                    cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64"
     # Validate found path
     if cuda_home is None or not os.path.exists(cuda_home):
@@ -89,13 +87,13 @@ def _find_cuda_home() -> str:
 def _find_rocm_home() -> str:
     """Find the ROCM install path."""
-    rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME')
+    rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME")
     if rocm_home is None:
         rocmcc_path = shutil.which("hipcc")
         if rocmcc_path is not None:
             rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
         else:
-            rocm_home = '/opt/rocm'
+            rocm_home = "/opt/rocm"
             if not os.path.exists(rocm_home):
                 rocm_home = None
     return rocm_home if rocm_home is not None else ""
@@ -104,6 +102,7 @@ def _find_rocm_home() -> str:
 # Cache control
 class CacheState:
     """Class to manage global kernel caching state."""
+
     _enabled = True

     @classmethod
@@ -230,13 +229,11 @@ class Environment:
     TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))
     # Kernel Build options
-    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
-                                           "1")  # print kernel name on compile
-    TILELANG_DISABLE_CACHE = EnvVar(
-        "TILELANG_DISABLE_CACHE",
-        "0")  # disable kernel cache, usually for unit testing / debugging, high priority
-    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE",
-                                  "0")  # DEPRECATED! clear cache automatically if set
+    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1")  # print kernel name on compile
+    TILELANG_DISABLE_CACHE = EnvVar(
+        "TILELANG_DISABLE_CACHE", "0"
+    )  # disable kernel cache, usually for unit testing / debugging, high priority
+    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # DEPRECATED! clear cache automatically if set
     # Kernel selection options
     # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
@@ -244,12 +241,9 @@ class Environment:
     # Auto-tuning settings
     TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
-    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
-                                                "0.9")  # percent of CPUs used
-    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
-                                             "-1")  # -1 means auto
-    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
-                                                "-1")  # -1 means no limit
+    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9")  # percent of CPUs used
+    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1")  # -1 means auto
+    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1")  # -1 means no limit
     # TVM integration
     SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
@@ -323,18 +317,18 @@ def prepend_pythonpath(path):
 if env.TVM_IMPORT_PYTHON_PATH is not None:
     prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
 else:
-    tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
+    tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python")
     assert os.path.exists(tvm_path), tvm_path
     if tvm_path not in sys.path:
         prepend_pythonpath(tvm_path)
         env.TVM_IMPORT_PYTHON_PATH = tvm_path

 # By default, the built TVM-related libraries are stored in TL_LIBS.
 if os.environ.get("TVM_LIBRARY_PATH") is None:
-    os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
+    os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)

 # Initialize CUTLASS paths
 if os.environ.get("TL_CUTLASS_PATH", None) is None:
-    cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include')
+    cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include")
     if os.path.exists(cutlass_inc_path):
         os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path
     else:
@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
 # Initialize COMPOSABLE_KERNEL paths
 if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
-    ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include')
+    ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include")
     if os.path.exists(ck_inc_path):
         os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path
     else:
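The _find_cuda_home hunks only touch quoting, but they expose a three-guess discovery strategy. A condensed, runnable sketch of that strategy (control flow simplified; the HPC SDK fallback is omitted):

import glob
import os
import shutil
import sys

def find_cuda_home() -> str | None:
    # Guess #1: explicit environment variables.
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home is None:
        # Guess #2: derive the install root from an nvcc on PATH (<home>/bin/nvcc).
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        # Guess #3: platform-default install locations.
        elif sys.platform == "win32":
            cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
            cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0]
        elif os.path.exists("/usr/local/cuda"):
            cuda_home = "/usr/local/cuda"
    # Validate whatever was found.
    if cuda_home is None or not os.path.exists(cuda_home):
        return None
    return cuda_home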
tilelang/intrinsics/mfma_layout.py

@@ -4,7 +4,7 @@ import tilelang.language as T
 def shared_16x4_to_local_64x1_layout_A(i, j):
-    thread_id = (j * 16 + i)
+    thread_id = j * 16 + i
     return thread_id, convert(0)
@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
 def shared_4x16_to_local_64x1_layout_B(i, j):
-    thread_id = (i * 16 + j)
+    thread_id = i * 16 + j
     return thread_id, convert(0)
@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_C(i, j):
     thread_id = j + (i // 4) * 16
-    local = (i % 4)
+    local = i % 4
     return thread_id, local
@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_A(i, j):
     thread_id = i + 16 * (j // 4)
-    local = (j % 4)
+    local = j % 4
     return thread_id, local
@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_B(i, j):
     thread_id = j + (i // 4) * 16
-    local = (i % 4)
+    local = i % 4
     return thread_id, local
@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
 def shared_16x32_to_local_64x8_layout_A(i, j):
     thread_id = i + 16 * (j // 8)
-    local = (j % 8)
+    local = j % 8
     return thread_id, local
@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
 def shared_16x32_to_local_64x8_layout_B(i, j):
     thread_id = j + (i // 8) * 16
-    local = (i % 8)
+    local = i % 8
     return thread_id, local
@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
 def shared_16x64_to_local_64x16_layout_A(i, j):
     thread_id = i + 16 * (j // 16)
-    local = (j % 16)
+    local = j % 16
     return thread_id, local
@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
 def shared_16x64_to_local_64x16_layout_B(i, j):
     thread_id = i + 16 * (j // 16)
-    local = (j % 16)
+    local = j % 16
     return thread_id, local
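These hunks only strip redundant parentheses, but the functions themselves are forward maps from a 16x16 shared-memory tile to a 64-thread x 4-element register layout. A round-trip check for one pair: the forward body is copied from the hunk above, while the inverse body is derived by hand here (the file shows only its name), so treat it as an illustration rather than the shipped implementation:

def shared_16x16_to_local_64x4_layout_A(i, j):
    # Forward map, verbatim from the diff: rows index the lane within a
    # 16-lane group, columns pick the group and the element within it.
    thread_id = i + 16 * (j // 4)
    local = j % 4
    return thread_id, local

def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
    # Hand-derived inverse of the forward map above.
    i = thread_id % 16
    j = (thread_id // 16) * 4 + local_id
    return i, j

for i in range(16):
    for j in range(16):
        tid, local = shared_16x16_to_local_64x4_layout_A(i, j)
        assert 0 <= tid < 64 and 0 <= local < 4
        assert thread_id_shared_access_64x4_to_16x16_layout_A(tid, local) == (i, j)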
tilelang/intrinsics/mfma_macro_generator.py
View file @
29051439
...
@@ -6,7 +6,7 @@ from tvm import tir
...
@@ -6,7 +6,7 @@ from tvm import tir
from
tvm.ir
import
Range
from
tvm.ir
import
Range
from
tvm.tir
import
PrimExpr
,
IndexMap
,
Buffer
,
Var
,
BufferRegion
,
BufferLoad
from
tvm.tir
import
PrimExpr
,
IndexMap
,
Buffer
,
Var
,
BufferRegion
,
BufferLoad
from
tvm.runtime
import
convert
from
tvm.runtime
import
convert
from
.utils
import
(
mfma_store_index_map
)
from
.utils
import
mfma_store_index_map
from
typing
import
Literal
,
Callable
from
typing
import
Literal
,
Callable
from
tilelang.utils
import
is_fragment
from
tilelang.utils
import
is_fragment
...
@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
reduce_k
=
reduce_k
self
.
reduce_k
=
reduce_k
self
.
threads
=
(
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
)
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
self
.
num_elems_per_byte
=
num_elems_per_byte
self
.
thread_var
=
thread_var
self
.
thread_var
=
thread_var
...
@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
def
_initialize_mfma_prefix
(
self
,
k_dim
=
16
):
def
_initialize_mfma_prefix
(
self
,
k_dim
=
16
):
in_dtype
,
out_dtype
=
self
.
a_dtype
,
self
.
accum_dtype
in_dtype
,
out_dtype
=
self
.
a_dtype
,
self
.
accum_dtype
M_DIM
,
N_DIM
=
self
.
M_DIM
,
self
.
N_DIM
M_DIM
,
N_DIM
=
self
.
M_DIM
,
self
.
N_DIM
out_dtype_abbrv
=
{
out_dtype_abbrv
=
{
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"int32"
:
"i32"
}[
out_dtype
]
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"int32"
:
"i32"
}[
out_dtype
]
in_dtype_abbrv
=
{
in_dtype_abbrv
=
{
"bfloat16"
:
"bf16"
,
"bfloat16"
:
"bf16"
,
...
@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
...
@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
self
.
b_preshuffle
=
b_preshuffle
self
.
b_preshuffle
=
b_preshuffle
def
get_ldmatrix_index_map
(
self
,
is_b
=
False
):
def
get_ldmatrix_index_map
(
self
,
is_b
=
False
):
k_dim
=
self
.
k_dim
*
self
.
k_pack
k_dim
=
self
.
k_dim
*
self
.
k_pack
transposed
=
self
.
a_transposed
if
not
is_b
else
self
.
b_transposed
transposed
=
self
.
a_transposed
if
not
is_b
else
self
.
b_transposed
if
k_dim
==
4
:
if
k_dim
==
4
:
...
@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
...
@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
reverse_index_map
=
thread_id_shared_access_64x1_to_16x4_layout_A
reverse_index_map
=
thread_id_shared_access_64x1_to_16x4_layout_A
if
is_b
:
if
is_b
:
index_map
=
shared_16x4_to_local_64x1_layout_A
if
transposed
else
shared_4x16_to_local_64x1_layout_B
index_map
=
shared_16x4_to_local_64x1_layout_A
if
transposed
else
shared_4x16_to_local_64x1_layout_B
reverse_index_map
=
thread_id_shared_access_64x1_to_16x4_layout_A
if
transposed
else
thread_id_shared_access_64x1_to_4x16_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x1_to_16x4_layout_A
if
transposed
else
thread_id_shared_access_64x1_to_4x16_layout_B
)
elif
k_dim
==
16
:
elif
k_dim
==
16
:
index_map
=
shared_16x16_to_local_64x4_layout_B
if
transposed
else
shared_16x16_to_local_64x4_layout_A
index_map
=
shared_16x16_to_local_64x4_layout_B
if
transposed
else
shared_16x16_to_local_64x4_layout_A
reverse_index_map
=
thread_id_shared_access_64x4_to_16x16_layout_B
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x4_to_16x16_layout_B
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_A
)
if
is_b
:
if
is_b
:
index_map
=
shared_16x16_to_local_64x4_layout_A
if
transposed
else
shared_16x16_to_local_64x4_layout_B
index_map
=
shared_16x16_to_local_64x4_layout_A
if
transposed
else
shared_16x16_to_local_64x4_layout_B
reverse_index_map
=
thread_id_shared_access_64x4_to_16x16_layout_A
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x4_to_16x16_layout_A
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_B
)
elif
k_dim
==
32
:
elif
k_dim
==
32
:
index_map
=
shared_16x32_to_local_64x8_layout_B
if
transposed
else
shared_16x32_to_local_64x8_layout_A
index_map
=
shared_16x32_to_local_64x8_layout_B
if
transposed
else
shared_16x32_to_local_64x8_layout_A
reverse_index_map
=
thread_id_shared_access_64x8_to_16x32_layout_B
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x8_to_16x32_layout_B
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_A
)
if
is_b
:
if
is_b
:
index_map
=
shared_16x32_to_local_64x8_layout_A
if
transposed
else
shared_16x32_to_local_64x8_layout_B
index_map
=
shared_16x32_to_local_64x8_layout_A
if
transposed
else
shared_16x32_to_local_64x8_layout_B
reverse_index_map
=
thread_id_shared_access_64x8_to_16x32_layout_A
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x8_to_16x32_layout_A
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_B
)
elif
k_dim
==
64
:
elif
k_dim
==
64
:
index_map
=
shared_16x64_to_local_64x16_layout_B
if
transposed
else
shared_16x64_to_local_64x16_layout_A
index_map
=
shared_16x64_to_local_64x16_layout_B
if
transposed
else
shared_16x64_to_local_64x16_layout_A
reverse_index_map
=
thread_id_shared_access_64x16_to_16x64_layout_B
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x16_to_16x64_layout_B
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_A
)
if
is_b
:
if
is_b
:
index_map
=
shared_16x64_to_local_64x16_layout_A
if
transposed
else
shared_16x64_to_local_64x16_layout_B
index_map
=
shared_16x64_to_local_64x16_layout_A
if
transposed
else
shared_16x64_to_local_64x16_layout_B
reverse_index_map
=
thread_id_shared_access_64x16_to_16x64_layout_A
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x16_to_16x64_layout_A
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_B
)
else
:
else
:
raise
ValueError
(
"k_dim must be 4 or 16 or 32 or 64 currently"
)
raise
ValueError
(
"k_dim must be 4 or 16 or 32 or 64 currently"
)
...
@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
...
@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
else
:
else
:
return
self
.
thread_var
return
self
.
thread_var
def
extract_thread_binding
(
self
,
def
extract_thread_binding
(
self
,
thread_id
,
is_m_first
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
thread_id
,
"""
is_m_first
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
'''
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
"""
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
'''
WARP_SIZE
=
self
.
WARP_SIZE
WARP_SIZE
=
self
.
WARP_SIZE
block_row_warps
=
self
.
block_row_warps
block_row_warps
=
self
.
block_row_warps
block_col_warps
=
self
.
block_col_warps
block_col_warps
=
self
.
block_col_warps
...
@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
...
@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
is_m_first
=
self
.
is_m_first
is_m_first
=
self
.
is_m_first
if
is_m_first
:
if
is_m_first
:
lane_id
,
warp_n
,
warp_m
=
thread_id
%
WARP_SIZE
,
(
lane_id
,
warp_n
,
warp_m
=
(
thread_id
//
thread_id
%
WARP_SIZE
,
WARP_SIZE
)
%
block_col_warps
,
(
thread_id
//
(
thread_id
//
WARP_SIZE
)
%
block_col_warps
,
(
WARP_SIZE
*
block_col_warps
))
%
block_row_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_col_warps
))
%
block_row_warps
,
)
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
else
:
else
:
lane_id
,
warp_m
,
warp_n
=
thread_id
%
WARP_SIZE
,
(
lane_id
,
warp_m
,
warp_n
=
(
thread_id
//
thread_id
%
WARP_SIZE
,
WARP_SIZE
)
%
block_row_warps
,
(
thread_id
//
(
thread_id
//
WARP_SIZE
)
%
block_row_warps
,
(
WARP_SIZE
*
block_row_warps
))
%
block_col_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_row_warps
))
%
block_col_warps
,
)
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
...
@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
...
@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
else
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
...
@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
)
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_base1
+
r
+
col
]
else
:
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
...
@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
...
@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
)
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_base1
+
r
+
col
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
def
mfma
(
self
,
def
mfma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
warp_rows
=
self
.
warp_rows
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
local_size_a
=
self
.
local_size_a
...
@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
...
@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
for
local_id
in
T
.
vectorized
(
local_size_out
):
for
local_id
in
T
.
vectorized
(
local_size_out
):
row
,
col
=
T
.
meta_var
(
mfma_store_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
mfma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
]
j
*
local_size_out
+
local_id
]
else
:
else
:
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
col
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
j
*
local_size_out
+
local_id
]
]
@
T
.
macro
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter:
            for i, j in T.grid(warp_rows, warp_cols):
                for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                    C_buf[
                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col,
                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]

        return (
            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
            if is_global
            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
        )
    def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MFMA results into a fragment buffer.
...
@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter:
        transform_func: Callable = None
        if matrix_is_a:
            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")
...
@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter:
            return local_id

        base_fragment = T.Fragment(
            [micro_size_s, micro_size_r * self.k_pack]
            if is_sr_axis_order
            else [micro_size_r * self.k_pack, micro_size_s],
            forward_thread_fn=forward_thread,
            forward_index_fn=forward_index,
        )
...
@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
...
@@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter:
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
    def __init__(
        self,
        a_dtype: str = "float16",
...
@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                            rk * (chunk // micro_size_k) + ki,
                            warp_m * warp_rows + i,
                        )
                    )
                    A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
            else:
                print(self.a_preshuffle)
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]

        return (
            _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
            if is_global
            else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
        )

    def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
        warp_cols = self.warp_cols
...
@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                            warp_n * warp_cols + j,
                            rk * (chunk // micro_size_k) + ki,
                        )
                    )
                    B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
            else:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
...
@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                            rk * (chunk // micro_size_k) + ki,
                            warp_n * warp_cols + j,
                        )
                    )
                    B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]

        return (
            _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
            if is_global
            else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
        )
tilelang/intrinsics/mma_layout.py
View file @ 29051439
...
@@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
    """
    groupID = %laneid >> 2
    threadID_in_group = %laneid % 4
    row = groupID       for ai where 0 <= i < 2 || 4 <= i < 6
          groupID + 8   Otherwise
    col = (threadID_in_group * 2) + (i & 0x1)       for ai where i < 4
          (threadID_in_group * 2) + (i & 0x1) + 8   for ai where i >= 4
    """
    row = (thread_id // 4) + 8 * (local_id % 4 // 2)
    col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
...
@@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
    """
    groupID = %laneid >> 2
    threadID_in_group = %laneid % 4
    row = (threadID_in_group * 2) + (i & 0x1)       for bi where i < 2
          (threadID_in_group * 2) + (i & 0x1) + 8   for bi where i >= 2
    col = groupID
    """
    col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8
    row = (thread_id // 4) + 8 * (local_id // 4)
...
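A quick sanity check for mappings like the A-load above is to enumerate every (thread_id, local_id) pair and confirm the 16x16 tile is covered exactly once; a minimal standalone sketch (the mapping body is copied from the hunk above):

# Sketch: verify mma_load_a_32x8_to_shared_16x16_layout is a bijection
# from 32 lanes x 8 register slots onto the 16x16 shared tile.
def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = (thread_id // 4) + 8 * (local_id % 4 // 2)
    col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
    return row, col

coords = {mma_load_a_32x8_to_shared_16x16_layout(t, l) for t in range(32) for l in range(8)}
assert len(coords) == 16 * 16  # 256 distinct (row, col) pairs => bijection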
tilelang/intrinsics/mma_macro_generator.py
View file @ 29051439
...
@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter:
    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
        from .utils import mma_store_index_map, mma_store_index_map_fp64

        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
        if DataType(self.accum_dtype).bits == 64:
            index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
...
@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter:
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

    def extract_thread_binding(
        self, thread_id: PrimExpr, is_m_first: bool | None = None
    ) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
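As a plain-Python illustration of the decomposition this docstring describes (the exact formula below is an assumption matching the is_m_first=True ordering, not the emitter's code):

# Illustrative only: is_m_first=True means
# thread_id = (warp_m * block_row_warps + warp_n) * warp_size + lane_id.
WARP_SIZE = 32
block_row_warps, block_col_warps = 2, 2  # assumed block shape

def extract_thread_binding(thread_id):
    lane_id = thread_id % WARP_SIZE
    warp_n = (thread_id // WARP_SIZE) % block_row_warps
    warp_m = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_n, warp_m

assert extract_thread_binding(32) == (0, 1, 0)  # next warp along n
assert extract_thread_binding(64) == (0, 0, 1)  # next warp along m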
...
@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter:
            )
            return lane_id, warp_n, warp_m

    def ldmatrix_a(
        self,
        A_local_buf: Buffer,
        A_shared_buf: Buffer | BufferRegion,
        ki: PrimExpr,
        rk: PrimExpr | None = 0,
    ):
        # Fast path for fp64: no ldmatrix support, do direct per-lane loads
        if DataType(self.a_dtype).bits == 64:
            warp_row_tiles = self.warp_row_tiles
...
@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter:
            for i in T.serial(warp_rows):
                # Assign A_shared_buf_elem
                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
                A_shared_buf_elem = (
                    A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk]
                )
                if ldmatrix_available:
                    T.ptx_ldmatrix(
...
@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter:
                    for j in T.serial(local_size_a):
                        mi, mk = mma_load_layout(tx, j)
                        if a_transposed:
                            A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
                        else:
                            A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]

        return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

    def ldmatrix_b(
        self,
        B_local_buf: Buffer,
        B_shared_buf: Buffer | BufferRegion,
        ki: PrimExpr,
        rk: PrimExpr | None = 0,
    ):
        # Fast path for fp64: no ldmatrix support, do direct per-lane loads
        if DataType(self.b_dtype).bits == 64:
            warp_col_tiles = self.warp_col_tiles
...
@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter:
        B_base0 = B_region.region[-2].min
        B_base1 = B_region.region[-1].min
        B_stride_last = B_buf.shape[-1]
        replicate_b = self.n_dim == 16
        # ldmatrix cannot be used for int8 + trans case.
        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...
@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter:
                )
                if ldmatrix_available:
                    B_shared_buf_elem = (
                        B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi]
                    )
                    T.ptx_ldmatrix(
                        b_dtype,
...
@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter:
                    for j in T.serial(local_size_b):
                        mi, mk = mma_load_layout(tx, j)
                        if b_transposed:
                            B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
                        else:
                            B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

    def mma(
        self,
        A_local_buf: Buffer,
        B_local_buf: Buffer,
        C_local_buf: Buffer,
        k_inner: PrimExpr | None = 0,
    ):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter:
        accum_dtype = self.accum_dtype
        accum_dtype_abbrv = self.accum_dtype_abbrv
        mma_prefix = self.mma_prefix
        replicate_b = self.n_dim == 16
        a_is_fragment = is_fragment(A_local_buf)
        b_is_fragment = is_fragment(B_local_buf)
...
@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter:
                    B_local_buf.data,
                    b_local_stride + j * local_size_b + lift(local_size_b) // 2,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    T.bool(False),  # saturate
                )
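The accumulator offset in this call flattens a (warp_rows, warp_cols, local_size_out) register file, with the second half of each accumulator starting at local_size_out // 2; the same arithmetic in isolation (tile counts below are assumed, for illustration only):

# Sketch of the flat C-register offset used by the ptx_mma call above.
warp_cols, local_size_out = 2, 8  # assumed

def c_offset(i, j, second_half=False):
    base = i * warp_cols * local_size_out + j * local_size_out
    return base + (local_size_out // 2 if second_half else 0)

assert c_offset(0, 0) == 0
assert c_offset(0, 0, second_half=True) == 4
assert c_offset(0, 1) == 8   # next tile along n
assert c_offset(1, 0) == 16  # next tile along m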
...
@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter:
                    local_id = local_id_o * 2 + local_id_i
                    row, col = T.meta_var(mma_store_index_map(tx, local_id))
                    if C_buf_dims == 2:
                        C_buf[
                            (warp_m * warp_rows + i) * M_DIM + row,
                            (warp_n * warp_cols + j) * n_dim + col,
                        ] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
                    else:
                        C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
                            i * (warp_cols * local_size_out) + j * local_size_out + local_id
                        ]

        @T.macro
        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...
@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter:
                    C_buf[
                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]

        return (
            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
            if is_global
            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
        )

    def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter:
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        if matrix_is_a:
            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")
...
@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
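Conceptually, each repeat above scales the fragment's spatial extent while replicate duplicates the same data across warps; a rough bookkeeping sketch in plain Python (this is not the T.Fragment API, just the shape algebra, with assumed repeat counts):

# Rough shape algebra for the repeat/replicate chain above (illustrative).
def repeat(frag, reps):
    (s0, s1), copies = frag
    return ((s0 * reps[0], s1 * reps[1]), copies)

def replicate(frag, r):
    shape, copies = frag
    return (shape, copies * r)

base = ((16, 16), 1)                        # one micro tile per warp
warp = repeat(base, (2, 2))                 # warp_s = warp_r = 2 (assumed)
block = replicate(repeat(warp, (2, 1)), 2)  # block_s = 2, replicate = 2
assert block == ((64, 32), 2)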
...
@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter:
        from tilelang.utils import is_fragment

        shape = local_buf.shape
        assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
        inverse_mma_store_layout = self.get_store_index_map(inverse=True)
        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
...
@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
                    ".b16",
                    A_local_buf.data,
                    i * local_size_a,
                    T.address_of(
                        A_shared_buf[
                            warp_m * warp_row_tiles + i * micro_size_x,
                            rk * chunk + ki * micro_size_k,
                        ]
                    ),
                    get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
                )
            elif transform_kind_a == TransformKind.InterWarpTransform:
...
@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
                        warp_m * warp_rows + j,
                        rk * (chunk // micro_size_k) + ki,
                    )
                    rii, rjj = (
                        (tx * local_size_a + local_id) // micro_size_k,
                        (tx * local_size_a + local_id) % micro_size_k,
                    )
                    A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj]
            else:
                raise ValueError("Unsupported TransformKind for Input A")
...
@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
                        warp_n * warp_cols + j,
                        rk * (chunk // micro_size_k) + ki,
                    )
                    rii, rjj = (
                        (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte),
                        (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte),
                    )
                    B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj]
            else:
                raise ValueError("Unsupported TransformKind for Input B")
...
@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
...
@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
tilelang/intrinsics/mma_sm70_layout.py
View file @ 29051439
...
@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
    row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
    col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
    return row, col
...
@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
    row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4))
    col = local_id
    return row, col
...
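The fp32 store mapping above can be checked the same way as the load layouts: 32 lanes x 8 locals should cover the 16x16 tile exactly once (a standalone sketch, body copied from the hunk above):

def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
    row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
    col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
    return row, col

coords = {mma_32x8_to_shared_16x16_layout_fp32(t, l) for t in range(32) for l in range(8)}
assert len(coords) == 256  # every (row, col) of the 16x16 tile hit exactly once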
tilelang/intrinsics/mma_sm70_macro_generator.py
View file @ 29051439
...
@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter:
    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
        index_map = IndexMap.from_func(
            mma_32x8_to_shared_16x16_layout_fp32
            if self.accum_dtype == "float32"
            else mma_32x8_to_shared_16x16_layout_fp16,
            index_dtype="int32",
        )
        if not inverse:
            return index_map
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

    def extract_thread_binding(
        self, thread_id: PrimExpr, is_m_first: bool | None = None
    ) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter:
            )
            return lane_id, warp_n, warp_m

    def ldmatrix_a(
        self,
        A_local_buf: Buffer,
        A_shared_buf: Buffer | BufferRegion,
        ki: PrimExpr,
        rk: PrimExpr | None = 0,
    ):
        warp_row_tiles = self.warp_row_tiles
        warp_rows = self.warp_rows
        chunk = self.chunk
...
@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter:
        return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

    def ldmatrix_b(
        self,
        B_local_buf: Buffer,
        B_shared_buf: Buffer | BufferRegion,
        ki: PrimExpr,
        rk: PrimExpr | None = 0,
    ):
        warp_col_tiles = self.warp_col_tiles
        warp_cols = self.warp_cols
        chunk = self.chunk
...
@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter:
                for j in T.vectorized(local_size_b):
                    if b_transposed:
                        mi, mk = mma_load_layout(tx, j)
                        B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
                    else:
                        mk, mi = mma_load_layout(tx, j)
                        B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

        return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)

    def mma(
        self,
        A_local_buf: Buffer,
        B_local_buf: Buffer,
        C_local_buf: Buffer,
        k_inner: PrimExpr | None = 0,
    ):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter:
        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

    def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter:
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        if matrix_is_a:
            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")
...
@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter:
            return lane_id, local_id

        base_fragment = T.Fragment(
            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
            forward_fn=forward,
            replicate=2,
        )
        warp_rows, warp_cols = self.warp_rows, self.warp_cols
        chunk = self.chunk
...
@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                ).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat(
                    [1, block_s], repeat_on_thread=True, lower_dim_first=True
                )
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
...
tilelang/intrinsics/mma_sp_layout.py
View file @ 29051439
...
@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int:
    return (thread_id // 4) * 2 + (thread_id % 4) % 2


def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_32bit(thread_id)
    row = logical_id // 4 + local_id * 8
    col = logical_id % 4
    return row, col


def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_32bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = logical_id % 2
    return row, col


def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
    return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit


def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
    return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit


def get_logical_id_8bit(thread_id: int) -> int:
    return thread_id


def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = (logical_id % 4) // 2 * 4 + local_id
    return row, col


def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = (logical_id % 4) // 2 * 2 + local_id
    return row, col


def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    # local_id is always 0
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 4 + (logical_id % 2) * 8
...
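One property worth noting in get_logical_id_32bit above: each 32-bit metadata word is owned by two lanes (t and t + 2 within a group of four), which a small sketch can confirm:

from collections import Counter

def get_logical_id_32bit(thread_id):
    return (thread_id // 4) * 2 + (thread_id % 4) % 2

counts = Counter(get_logical_id_32bit(t) for t in range(32))
assert sorted(counts) == list(range(16))     # logical ids 0..15
assert all(c == 2 for c in counts.values())  # exactly two lanes per id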
tilelang/intrinsics/mma_sp_macro_generator.py
View file @ 29051439
...
@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter:
    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
        self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
        self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
        self.local_size_b = (n_dim * k_dim) // warp_size
        self.local_size_out = (m_dim * n_dim) // warp_size
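For a single 16x16x16 fragment, the arithmetic above works out as follows (SPARSE_FACTOR = 2 for 2:4 sparsity; e_factor = 8 and a replicate factor of 1 are assumed, illustrative values, not read from the source):

# Worked example of _initialize_local_size with assumed metadata factors.
m_dim = n_dim = k_dim = 16
warp_size = 32
SPARSE_FACTOR = 2              # 2:4 structured sparsity halves A's payload
e_factor, e_replicate = 8, 1   # assumed metadata packing

local_size_a = (m_dim * k_dim) // warp_size // SPARSE_FACTOR            # 4
local_size_e = (m_dim * k_dim) // e_factor // warp_size * e_replicate   # 1
local_size_b = (n_dim * k_dim) // warp_size                             # 8
local_size_out = (m_dim * n_dim) // warp_size                           # 8
assert (local_size_a, local_size_e, local_size_b, local_size_out) == (4, 1, 8, 8)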
...
@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter:
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

    def extract_thread_binding(
        self, thread_id: PrimExpr, is_m_first: bool | None = None
    ) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter:
            for i in T.serial(warp_rows):
                # Assign A_shared_buf_elem
                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
                A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
                if ldmatrix_available:
...
@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter:
                else:
                    for j in T.serial(local_size_a):
                        mi, mk = mma_load_layout(tx, j)
                        A_local_buf[i * local_size_a + j] = (
                            A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
                        )

        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
...
@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter:
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            for i in T.serial(warp_rows):
                # Assign E_shared_buf_elem
                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor
                for j in T.serial(local_size_e):
                    mi, mk = mma_load_layout(tx, j)
                    E_local_buf[i * local_size_e + j] = (
                        E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]
                    )

        return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
...
@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter:
        b_dtype = self.b_dtype
        b_transposed = self.b_transposed
        thread_binding = self.get_thread_binding()
        replicate_b = self.n_dim == 16
        # ldmatrix cannot be used for int8 + trans case.
        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...
@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter:
                )
                if ldmatrix_available:
                    B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
                    if replicate_b:
                        T.ptx_ldmatrix(
...
@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter:
                            B_local_buf.data,
                            i * local_size_b + lift(local_size_b) // 2,
                            T.address_of(B_shared_buf_elem),
                            get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed),
                        )
                    else:
                        T.ptx_ldmatrix(
...
@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter:
                    # must be transposed.
                    for j in T.serial(local_size_b):
                        mi, mk = mma_load_layout(tx, j)
                        B_local_buf[i * local_size_b + j] = (
                            B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi]
                        )

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

    def mma_sp(
        self,
        A_local_buf: Buffer,
        E_local_buf: Buffer,
        B_local_buf: Buffer,
        C_local_buf: Buffer,
        k_inner: PrimExpr = 0,
    ):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter:
        accum_dtype = self.accum_dtype
        accum_dtype_abbrv = self.accum_dtype_abbrv
        mma_prefix = self.mma_prefix
        replicate_b = self.n_dim == 16
        a_is_fragment = is_fragment(A_local_buf)
        e_is_fragment = is_fragment(E_local_buf)
...
@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter:
                    B_local_buf.data,
                    b_local_stride + j * local_size_b + lift(local_size_b) // 2,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    E_local_buf.data,  # metadata
                    e_local_stride + i * local_size_e,  # metadata offset
                    self.SPARSE_SELECTOR,  # sparse_selector
...
@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter:
                    local_id = local_id_o * 2 + local_id_i
                    row, col = T.meta_var(mma_store_index_map(tx, local_id))
                    if C_buf_dims == 2:
                        C_buf[
                            (warp_m * warp_rows + i) * M_DIM + row,
                            (warp_n * warp_cols + j) * n_dim + col,
                        ] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
                    else:
                        C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
                            i * (warp_cols * local_size_out) + j * local_size_out + local_id
                        ]

        @T.macro
        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...
@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter:
                    C_buf[
                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]

        return (
            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
            if is_global
            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
        )

    def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
...
@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
            # so the b matrix expected a transposed basic layout
            transform_func: Callable = None
            if matrix_is_a:
                transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
            elif matrix_is_b:
                transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
            else:
                raise ValueError(f"Unsupported matrix {matrix}")
...
@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter:
            return local_id

        base_fragment = T.Fragment(
            [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r]
            if is_sr_axis_order
            else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
            forward_thread_fn=forward_thread,
            forward_index_fn=forward_index,
...
@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
...
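The repeat/replicate chain above builds the block-level fragment from a per-instruction base tile: repeat scales the index space across a warp and then across the block, while replicate duplicates the fragment over warps that hold identical data. A minimal pure-Python sketch of the shape bookkeeping only (illustrative names, not the tilelang T.Fragment API):

def repeat(shape, times):
    # The repeated fragment covers times[d] * shape[d] elements along each dim d.
    return [t * s for t, s in zip(times, shape)]

base = [16, 16]               # micro_size_s x micro_size_r
warp = repeat(base, [2, 4])   # repeat([warp_s, warp_r]) inside one warp
block = repeat(warp, [2, 1])  # repeat([block_s, 1]) across warps
print(block)                  # [64, 64] elements covered by the block fragment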
tilelang/intrinsics/tcgen05_macro_generator.py
View file @ 29051439
...
@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        is_m_first: bool = False,
        thread_var: Var | None = None,
    ):
-        super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
-                         block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
-                         num_elems_per_byte, is_m_first, thread_var)
+        super().__init__(
+            a_dtype,
+            b_dtype,
+            accum_dtype,
+            a_transposed,
+            b_transposed,
+            block_row_warps,
+            block_col_warps,
+            warp_row_tiles,
+            warp_col_tiles,
+            chunk,
+            reduce_k,
+            num_elems_per_byte,
+            is_m_first,
+            thread_var,
+        )

    def _assign_a_shared_layout(self, layout: Layout):
        self.a_shared_layout = layout
...
@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            raise ValueError(f"Unsupported swizzle mode: {layout}")

-    def tcgen05mma(self,
-                   A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False):
+    def tcgen05mma(
+        self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False
+    ):
        if is_tensor_memory(A_buf):
            return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
...
@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        elems_in_bits = DataType(self.a_dtype).bits
        elems_in_bytes = elems_in_bits // 8
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
-        ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        b_swizzle_atom_elems = (
+            n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        )
        accum_dtype_in_bits = DataType(accum_dtype).bits

        meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
        if len(meta) != 5:
            raise ValueError(
                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
            )
        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)

        # by default, we utilize non-swizzle layout offset
        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
        if not a_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
...
@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems

        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
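The defaults above fix the descriptor's leading/stride byte offsets (LBO/SBO) for the non-swizzled 8x8-core layout. A small sketch of the same arithmetic as it appears in this hunk, evaluated for a hypothetical fp16 tile with m_dim=64 and k_dim=32:

def lead_stride(elems_in_bytes, dim, k_dim, is_k_major):
    # Mirrors the non-swizzle defaults in the diff:
    #   LBO = 8*8*e (K-major) or 8*dim*e; SBO = 8*k_dim*e (K-major) or 8*8*e.
    lbo = 8 * 8 * elems_in_bytes if is_k_major else 8 * dim * elems_in_bytes
    sbo = 8 * k_dim * elems_in_bytes if is_k_major else 8 * 8 * elems_in_bytes
    return lbo, sbo

print(lead_stride(2, 64, 32, True))   # (128, 512): fp16, K-major
print(lead_stride(2, 64, 32, False))  # (1024, 128): fp16, M-major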
@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                for ki in T.unroll(0, (k_dim // micro_size_k)):
                    scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                    A_elem_offset = (
                        (ki % ak_atom_size) * micro_size_k
                        + i * atom_m * a_swizzle_atom_elems
                        + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
                        if a_is_k_major
                        else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
                    )
                    B_elem_offset = (
                        (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
                        + (ki % bk_atom_size) * micro_size_k
                        + j * atom_n * b_swizzle_atom_elems
                        if b_is_k_major
                        else (
                            ki * b_swizzle_atom_elems * micro_size_k
                            + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
                        )
                    )
                    A_byte_offset = A_elem_offset * elems_in_bytes
                    B_byte_offset = B_elem_offset * elems_in_bytes
                    C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
                    T.ptx_tcgen05_mma_ss(
                        a_dtype_abbrv,
...
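In the loop above, C_offset converts an (i, j) accumulator-tile coordinate into 32-bit tensor-memory banks, as the trailing comment notes. The same arithmetic in isolation, with placeholder values for n_dim and tmem_col_step:

def tmem_c_offset(i, j, n_dim, tmem_col_step, accum_bits):
    # (i * n_dim + j * tmem_col_step) elements, rescaled to 32-bit banks.
    return (i * n_dim + j * tmem_col_step) * accum_bits // 32

print(tmem_c_offset(0, 1, n_dim=128, tmem_col_step=64, accum_bits=32))  # 64
print(tmem_c_offset(1, 0, n_dim=128, tmem_col_step=64, accum_bits=32))  # 128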
@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
"""
assert
is_tensor_memory
(
tmem_buf
),
"tmem_buf must reside in tensor memory (shared.tmem)"
assert
is_tensor_memory
(
tmem_buf
),
"tmem_buf must reside in tensor memory (shared.tmem)"
if
len
(
tmem_buf
.
shape
)
!=
2
:
if
len
(
tmem_buf
.
shape
)
!=
2
:
raise
ValueError
(
raise
ValueError
(
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
m
=
int
(
tmem_buf
.
shape
[
0
])
m
=
int
(
tmem_buf
.
shape
[
0
])
n
=
int
(
tmem_buf
.
shape
[
1
])
n
=
int
(
tmem_buf
.
shape
[
1
])
...
@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        meta = self.get_tcgen5_mma_meta(m, n, k)
        if len(meta) != 5:
-            raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
-                             f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
+            raise ValueError(
+                f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
+            )
        atom_m, atom_n, _, _, _ = (int(x) for x in meta)
        if m % atom_m != 0 or n % atom_n != 0:
            raise ValueError(
                f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
            )

        def forward(i: PrimExpr, j: PrimExpr):
            atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
...
@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        return Layout([m, n], forward)

    def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
        return _ffi_api.get_tcgen5_mma_meta(
            int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)
        )

-    def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
-                              b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
+    def get_tcgen5_instr_desc(
+        self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int
+    ) -> PrimExpr:
        desc = _ffi_api.get_tcgen5_instr_desc(
            atom_m,
            atom_n,
...
tilelang/intrinsics/utils.py
View file @ 29051439
...
@@ -10,7 +10,7 @@ from .mma_layout import (
    mma_store_32x8_to_shared_16x16_layout,
    mma_store_32x2_to_shared_8x8_layout_fp64,
)
-from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
+from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m
from .mma_layout import get_swizzle_layout  # noqa: F401
from .mma_layout import make_mma_swizzle_layout  # noqa: F401
...
tilelang/intrinsics/wgmma_macro_generator.py
View file @ 29051439
...
@@ -15,9 +15,11 @@ from tilelang.layout import (
    make_linear_layout,
)
from tvm.runtime import convert
-from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a,
-                                            shared_16x16_to_mma_32x8_layout_sr_a,
-                                            shared_16x32_to_mma_32x16_layout_sr_a)
+from tilelang.intrinsics.mma_layout import (
+    shared_16x8_to_mma_32x4_layout_sr_a,
+    shared_16x16_to_mma_32x8_layout_sr_a,
+    shared_16x32_to_mma_32x16_layout_sr_a,
+)

lift = convert
...
@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        is_m_first: bool | None = False,
        thread_var: Var | None = None,
    ):
-        super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
-                         block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
-                         num_elems_per_byte, is_m_first, thread_var)
+        super().__init__(
+            a_dtype,
+            b_dtype,
+            accum_dtype,
+            a_transposed,
+            b_transposed,
+            block_row_warps,
+            block_col_warps,
+            warp_row_tiles,
+            warp_col_tiles,
+            chunk,
+            reduce_k,
+            num_elems_per_byte,
+            is_m_first,
+            thread_var,
+        )
        self._initialize_wgmma_prefix(self.n_dim)

    def _assign_a_shared_layout(self, layout: Layout):
...
@@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
    def _initialize_wgmma_prefix(self, n_dim: int = 16):
        inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
-        assert inst_n % 8 == 0, (
-            f"inst_n must be a multiple of 8, got {inst_n} "
-            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
+        assert inst_n % 8 == 0, (
+            f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
+        )
        # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
-        assert 8 <= inst_n <= 256, (
-            f"inst_n must be within [8, 256], got {inst_n} "
-            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
+        assert 8 <= inst_n <= 256, (
+            f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
+        )
        # 256 bits per instruction
        inst_k = 256 // DataType(self.a_dtype).bits
        self.wgmma_inst_m = inst_m
...
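The assertions above pin inst_n to a multiple of 8 within [8, 256], and inst_k follows from the 256-bit K extent of one WGMMA instruction. A quick standalone sketch of that derivation:

from math import gcd

def wgmma_inst_shape(warp_col_tiles, a_dtype_bits):
    inst_m, inst_n = 64, gcd(warp_col_tiles, 256)
    assert inst_n % 8 == 0 and 8 <= inst_n <= 256
    inst_k = 256 // a_dtype_bits  # 256 bits of K per instruction
    return inst_m, inst_n, inst_k

print(wgmma_inst_shape(128, 16))  # (64, 128, 16) for fp16
print(wgmma_inst_shape(96, 8))    # (64, 32, 32) for fp8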
@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            raise ValueError(f"Unsupported swizzle mode: {layout}")

-    def wgmma(self,
-              A_region: BufferRegion,
-              B_region: BufferRegion,
-              C_region: BufferRegion,
-              clear_accum: PrimExpr = False,
-              wg_wait: int = 0):
+    def wgmma(
+        self,
+        A_region: BufferRegion,
+        B_region: BufferRegion,
+        C_region: BufferRegion,
+        clear_accum: PrimExpr = False,
+        wg_wait: int = 0,
+    ):
        if is_fragment(A_region):
            return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
...
@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        elems_in_bytes = elems_in_bits // 8
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
-        ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        b_swizzle_atom_elems = (
+            n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        )
        accum_bits = DataType(accum_dtype).bits
        accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

        # by default, we utilize non-swizzle layout offset
        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
        if not a_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
...
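accum_regs above counts the 32-bit registers one warpgroup's accumulator fragment occupies, rounding up, so the operand fences can name an exact register count. The same computation in isolation, with assumed tile parameters:

def accum_reg_count(m_dim, warp_cols, local_size_out, accum_bits):
    # (m_dim // 64) WGMMA instructions, each holding warp_cols * local_size_out
    # accumulator elements per thread, packed into 32-bit registers.
    return ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

print(accum_reg_count(128, 4, 8, 32))  # 64 regs for an fp32 accumulator
print(accum_reg_count(64, 4, 8, 16))   # 16 regs for an fp16 accumulator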
@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        if a_m_axis_atoms <= 1:
            a_leading_byte_offset = 0
        else:
            a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (
                a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
            )
        if a_m_axis_atoms <= 1:
            a_stride_byte_offset = 8 * elems_in_bytes * m_dim
        else:
            a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems

        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        desc_a = T.alloc_wgmma_desc()
        desc_b = T.alloc_wgmma_desc()
-        T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
-                                      int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-        T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
-                                      int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+        T.initialize_wgmma_descriptor(
+            desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)
+        )
+        T.initialize_wgmma_descriptor(
+            desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)
+        )
        T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
        T.warpgroup_arrive()
...
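Note the >> 4 on both byte offsets when the descriptors are initialized: the shared-memory descriptor encodes LBO/SBO in 16-byte granules (see the PTX link cited in the comments above). A sketch of that encoding step:

def encode_desc_offsets(leading_byte_offset, stride_byte_offset):
    # Descriptor fields are expressed in units of 16 bytes, hence the shift.
    assert leading_byte_offset % 16 == 0 and stride_byte_offset % 16 == 0
    return leading_byte_offset >> 4, stride_byte_offset >> 4

print(encode_desc_offsets(128, 512))  # (8, 32)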
@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                    warp_i = (warp_m // 4) * num_inst_m + i
                    warp_j = warp_n * num_inst_n + j
                    A_offset = (
                        (ki % ak_atom_size) * micro_size_k
                        + warp_i * 64 * a_swizzle_atom_elems
                        + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
                        if a_is_k_major
                        else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
                    )
                    B_offset = (
                        (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
                        + (ki % bk_atom_size) * micro_size_k
                        + warp_j * wgmma_inst_n * b_swizzle_atom_elems
                        if b_is_k_major
                        else (
                            ki * b_swizzle_atom_elems * micro_size_k
                            + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
                        )
                    )
                    C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
-                    T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
-                                   a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
-                                   (A_offset * elems_in_bytes) >> 4, desc_b.data,
-                                   (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
-                                   scale_out, scale_in_a, scale_in_b)
+                    T.ptx_wgmma_ss(
+                        accum_dtype,
+                        wgmma_prefix,
+                        a_is_k_major,
+                        b_is_k_major,
+                        a_dtype_abbrv,
+                        b_dtype_abbrv,
+                        accum_dtype_abbrv,
+                        desc_a.data,
+                        (A_offset * elems_in_bytes) >> 4,
+                        desc_b.data,
+                        (B_offset * elems_in_bytes) >> 4,
+                        C_buf.data,
+                        C_offset,
+                        scale_out,
+                        scale_in_a,
+                        scale_in_b,
+                    )
                T.warpgroup_commit_batch()
                if wg_wait >= 0:
...
@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        return _warp_mma(A_ptr, B_ptr, C_buf)

-    def wgmma_rs(self,
-                 A_region: BufferRegion,
-                 B_region: BufferRegion,
-                 C_region: BufferRegion,
-                 clear_accum: PrimExpr = False,
-                 wg_wait: int = 0):
+    def wgmma_rs(
+        self,
+        A_region: BufferRegion,
+        B_region: BufferRegion,
+        C_region: BufferRegion,
+        clear_accum: PrimExpr = False,
+        wg_wait: int = 0,
+    ):
        local_size_a = self.local_size_a
        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
...
@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        b_is_k_major = self.b_transposed
        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
-        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
-        ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
+        b_swizzle_atom_elems = (
+            n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        )
+        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            desc_b = T.alloc_wgmma_desc()
-            T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
-                                          int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+            T.initialize_wgmma_descriptor(
+                desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)
+            )
            T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()
...
@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                    A_offset = ki * warp_rows * local_size_a + i * local_size_a
                    B_offset = (
                        (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
                        + warp_j * wgmma_inst_n * b_swizzle_atom_elems
                        + (ki % bk_atom_size) * micro_size_k
                        if b_is_k_major
                        else (
                            ki * b_swizzle_atom_elems * micro_size_k
                            + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
                        )
                    )
                    C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                    T.ptx_wgmma_rs(
                        accum_dtype,
...
@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A"], "matrix should be A for WGMMA"
        dtype = self.a_dtype
        dtype_bits = DataType(dtype).bits
...
@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        # the layout of mma.sync is row.col.
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)

        assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
...
@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        replicate = block_col_warps
        if is_sr_axis_order:
-            warp_fragment = base_fragment.repeat([block_s, 1],
-                                                 repeat_on_thread=True,
-                                                 lower_dim_first=False).replicate(replicate)
-            block_fragment = warp_fragment.repeat([warp_s, warp_r],
-                                                  repeat_on_thread=False,
-                                                  lower_dim_first=False)
+            warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
+            block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
        else:
            # rs condition, transposed_a matrix
-            warp_fragment = base_fragment.repeat([1, block_s],
-                                                 repeat_on_thread=True,
-                                                 lower_dim_first=False).replicate(replicate)
-            block_fragment = warp_fragment.repeat([warp_r, warp_s],
-                                                  repeat_on_thread=False,
-                                                  lower_dim_first=True)
+            warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
+            block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)

        return block_fragment
...
tilelang/ir.py
View file @ 29051439
...
@@ -7,23 +7,19 @@ from tilelang import _ffi_api


@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
...
@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
...
@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
...
@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.GemmWarpPolicy")
...
@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable):
    m_warp: int
    n_warp: int

    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool):
        _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma)
        return self.m_warp, self.n_warp
...
@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable):
    m_warp: int
    n_warp: int

    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int):
        _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits)
        return self.m_warp, self.n_warp


@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
...
@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
...
@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
...
@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
...
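compute_warp_partition above delegates to the C++ side through _ffi_api, so only its contract is visible here: split block_size // 32 warps into an m_warp x n_warp grid for the GEMM tile. A hypothetical pure-Python stand-in illustrating one plausible square-ish policy (not the actual tilelang heuristic):

def square_warp_partition(M, N, block_size, warp_size=32):
    num_warps = block_size // warp_size
    m_warp, best = 1, None
    for m in range(1, num_warps + 1):
        if num_warps % m:
            continue
        n = num_warps // m
        # Prefer the factorization whose per-warp tile is closest to square.
        score = abs(M / m - N / n)
        if best is None or score < best:
            m_warp, best = m, score
    return m_warp, num_warps // m_warp

print(square_warp_partition(128, 128, 256))  # (2, 4): 8 warps split over M x N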