OpenDAS / tilelang · Commit 29051439
"include/ck/utility/math.hpp" did not exist on "bbcb67d0aac81b51336981713662a726875ebd58"
[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Unverified commit 29051439, authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025. Parent: e84b24bc
Showing 20 changed files with 479 additions and 652 deletions (+479 −652).
tilelang/contrib/hipcc.py                        +2   -7
tilelang/contrib/nvcc.py                         +12  -25
tilelang/contrib/nvrtc.py                        +10  -11
tilelang/contrib/rocm.py                         +5   -2
tilelang/engine/lower.py                         +17  -22
tilelang/engine/param.py                         +3   -0
tilelang/engine/phase.py                         +6   -12
tilelang/env.py                                  +33  -39
tilelang/intrinsics/mfma_layout.py               +9   -9
tilelang/intrinsics/mfma_macro_generator.py      +88  -107
tilelang/intrinsics/mma_layout.py                +11  -11
tilelang/intrinsics/mma_macro_generator.py       +50  -93
tilelang/intrinsics/mma_sm70_layout.py           +3   -5
tilelang/intrinsics/mma_sm70_macro_generator.py  +21  -54
tilelang/intrinsics/mma_sp_layout.py             +9   -18
tilelang/intrinsics/mma_sp_macro_generator.py    +41  -74
tilelang/intrinsics/tcgen05_macro_generator.py   +49  -44
tilelang/intrinsics/utils.py                     +1   -1
tilelang/intrinsics/wgmma_macro_generator.py     +93  -86
tilelang/ir.py                                   +16  -32
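Every hunk below applies the same mechanical transformation: yapf's wrapped argument lists and single-quoted strings are rewritten into ruff format's output, with double quotes, redundant parentheses dropped, and short signatures joined onto a single line. A minimal sketch of the recurring before/after pattern (an invented example for illustration, not a line taken from this diff):

# yapf style (before): wrapped arguments, single quotes, redundant parens
def compile(code,
            target_format='ptx',
            verbose=False):
    flags = ['-O3', ('-c')]
    return flags

# ruff format style (after): one-line signature, double quotes, parens dropped
def compile(code, target_format="ptx", verbose=False):
    flags = ["-O3", "-c"]
    return flags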
tilelang/contrib/hipcc.py

@@ -16,12 +16,7 @@ from tvm.base import py_str
 from tvm.contrib.rocm import get_rocm_arch, find_rocm_path


-def compile_hip(code,
-                target_format="hsaco",
-                arch=None,
-                options=None,
-                path_target=None,
-                verbose=False):
+def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False):
     """Compile HIP code with hipcc.

     Parameters

@@ -61,7 +56,7 @@ def compile_hip(code,
     file_target = path_target if path_target else temp_target
     cmd = ["hipcc"]
-    cmd += ["-O3", '-c']
+    cmd += ["-O3", "-c"]
     if isinstance(arch, str):
         cmd += [f"--offload-arch={arch}"]
     if target_format == "hsaco":
tilelang/contrib/nvcc.py

 # pylint: disable=invalid-name
 # modified from apache tvm python/tvm/contrib/nvcc.py
 """Utility to invoke nvcc compiler in the system"""
+from __future__ import annotations
 import os

@@ -18,12 +19,7 @@ from tvm.base import py_str
 from tvm.contrib import utils


-def compile_cuda(code,
-                 target_format="ptx",
-                 arch=None,
-                 options=None,
-                 path_target=None,
-                 verbose=False):
+def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False):
     """Compile cuda code with NVCC from env.

     Parameters

@@ -67,7 +63,7 @@ def compile_cuda(code,
     temp_target = temp.relpath(f"{file_name}.{target_format}")
     pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
-    kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None))
+    kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None)
     if kernels_output_dir is not None:
         if not os.path.isdir(kernels_output_dir):
             os.makedirs(kernels_output_dir)

@@ -114,10 +110,7 @@ def compile_cuda(code,
         print(py_str(out))
     if proc.returncode != 0:
-        msg = f"{code}\n" \
-              f"Compilation error:\n" \
-              f"{py_str(out)}\n" \
-              f"Command: {' '.join(cmd)}\n"
+        msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n"
         raise RuntimeError(msg)
     with open(file_target, "rb") as f:

@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
     # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
     if compile_flags:
+        import shlex
         for flag in compile_flags:
             # Split each string like a shell would, preserving quoted args
             tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]

@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
     return options


-def get_ptx_from_source(code: str,
-                        compile_flags: list[str] | None = None,
-                        verbose: bool = False) -> str:
+def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
     """
     Compile CUDA C++ source to PTX using NVCC and return as text.

@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None:
     return None


-def get_sass_from_source(code: str,
-                         compile_flags: list[str] | None = None,
-                         verbose: bool = False) -> str:
+def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
     """
     Compile CUDA C++ source to CUBIN and disassemble to SASS.

@@ -246,9 +236,7 @@ def get_sass_from_source(code: str,
     cand_nvdisasm = _find_tool("nvdisasm")
     cand_cuobjdump = _find_tool("cuobjdump")
     if not cand_nvdisasm and not cand_cuobjdump:
-        raise RuntimeError(
-            "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
-        )
+        raise RuntimeError(
+            "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.")
     last_err: str | None = None
     try:
         # Attempt nvdisasm first

@@ -268,8 +256,7 @@ def get_sass_from_source(code: str,
                 return text
             last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
         # If we reach here, all attempts failed
-        raise RuntimeError(
-            f"SASS disassembly failed. Tried tools: "
-            f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
+        raise RuntimeError(
+            f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}"
+        )
     finally:
         with contextlib.suppress(Exception):
             os.remove(cubin_path)

@@ -438,8 +425,7 @@ def get_target_compute_version(target=None):
     if tvm.cuda(0).exist:
         return tvm.cuda(0).compute_version
-    raise ValueError("No CUDA architecture was specified or GPU detected."
-                     "Try specifying it by adding '-arch=sm_xx' to your target.")
+    raise ValueError(
+        "No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target."
+    )
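A side effect worth noting in the hunk above: the two adjacent literals in the old ValueError had no space between the sentences, and merging them into one literal preserves that glitch, so the message still reads "detected.Try". Adjacent Python string literals concatenate with no separator, as this standalone check confirms:

# Adjacent string literals concatenate verbatim; no separator is inserted,
# which is why the merged message lacks a space between the two sentences.
msg = ("No CUDA architecture was specified or GPU detected."
       "Try specifying it by adding '-arch=sm_xx' to your target.")
assert "detected.Try" in msg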
 def parse_compute_version(compute_version) -> tuple[int, int]:

@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None):
         warnings.warn(
             "Tensorcore will be disabled due to no CUDA architecture specified."
             "Try specifying it by adding '-arch=sm_xx' to your target.",
-            stacklevel=2)
+            stacklevel=2,
+        )
         return False
     compute_version = target.attrs["arch"]
     # Compute version will be in the form "sm_{major}{minor}"
tilelang/contrib/nvrtc.py

@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]:
     return (major, minor)


-def compile_cuda(code: str,
-                 target_format: Literal["ptx", "cubin"] = "ptx",
-                 arch: int | None = None,
-                 options: str | list[str] | None = None,
-                 verbose: bool = False) -> bytearray:
+def compile_cuda(
+    code: str,
+    target_format: Literal["ptx", "cubin"] = "ptx",
+    arch: int | None = None,
+    options: str | list[str] | None = None,
+    verbose: bool = False,
+) -> bytearray:
     """Compile cuda code with NVRTC.

     Parameters

@@ -43,8 +45,7 @@ def compile_cuda(code: str,
     if arch is None:
         # If None, then it will use `tvm.target.Target.current().arch`.
         # Target arch could be a str like "80", "90", "90a", etc.
-        major, minor = parse_compute_version(
-            get_target_compute_version(Target.current(allow_none=True)))
+        major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
         arch = major * 10 + minor
     prefix = "compute" if target_format == "ptx" else "sm"
     suffix = "a" if arch >= 90 else ""

@@ -77,8 +78,7 @@ def compile_cuda(code: str,
     compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0]
     if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-        msg = f"{code}\n" \
-              f"Compilation error:\n"
+        msg = f"{code}\nCompilation error:\n"
         if verbose:
             result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
             assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}"

@@ -105,7 +105,6 @@ def compile_cuda(code: str,
     assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}"
     # Destroy handler
-    assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, \
-        f"Failed to destroy program: {result}"
+    assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
     return result_bytes
tilelang/contrib/rocm.py

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Utility for ROCm backend"""
+# ruff: noqa
 import re
 import subprocess

@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"):
                 gpu_arch = match.group(1)
         return gpu_arch
     except subprocess.CalledProcessError:
-        print(f"Unable to execute rocminfo command, \
-            please ensure ROCm is installed and you have an AMD GPU on your system. \
-            using default {gpu_arch}.")
+        print(
+            f"Unable to execute rocminfo command, \
+            please ensure ROCm is installed and you have an AMD GPU on your system. \
+            using default {gpu_arch}."
+        )
         return gpu_arch
tilelang/engine/lower.py

 """The compiler for TL programs."""
+from __future__ import annotations
 import os

@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target):
 def has_device_kernel_launch(attrs) -> bool:
     """Check if the attributes indicate a device kernel launch."""
-    return bool(attrs and "calling_conv" in attrs and
-                attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
+    return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)


 def is_device_call_c_device(func: tir.PrimFunc):
     attrs = func.attrs
     calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT)
-    is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC)
+    is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC
     # Check if it's a C target
     if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked:

@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
         if var in func.buffer_map:
             tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
         else:
-            if var.dtype == 'handle':
+            if var.dtype == "handle":
                 raise ValueError(
-                    f'Handle parameter {var} must be mapped to a buffer.\n'
-                    f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.')
+                    f"Handle parameter {var} must be mapped to a buffer.\n"
+                    f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer."
+                )
             tensor_types.append(KernelParam.from_var(var))
     return tensor_types


 def canon_target_host(target: str | Target, target_host: str | Target | None):
     if not target_host:
         target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"

@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
     device_mod = tilelang.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
-            device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
     elif target.kind.name == "hip":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(
-            device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
     elif target.kind.name == "c":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
     elif target.kind.name == "llvm":

@@ -222,12 +220,12 @@ def lower(
     enable_host_codegen=False,
     enable_device_compile=False,
 ) -> CompiledArtifact:
-    '''
+    """
     enable_host_codegen: whether to enable host codegen, default is False, as we have our
     own host codegen implementation in jit.
     enable_device_compile: whether to enable device codegen, default is False, as we have our
     own device codegen implementation in jit.
-    '''
+    """
     mod = func_or_mod
     params = None

@@ -259,14 +257,11 @@ def lower(
     host_mod = tir.transform.Filter(_is_host_call)(mod)
     device_mod = tir.transform.Filter(_is_device_call)(mod)
-    codegen_mod = device_codegen(
-        device_mod, target) if enable_device_compile else device_codegen_without_compile(
-            device_mod, target)
+    codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target)
     if enable_host_codegen:
         host_mod = host_codegen(host_mod, target_host)
         host_mod.import_module(codegen_mod)
-        return CompiledArtifact(
-            host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
+        return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
     return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
tilelang/engine/param.py

 """The profiler and convert to torch utils"""
+from __future__ import annotations
 from dataclasses import dataclass

@@ -14,6 +15,7 @@ class KernelParam:
     Represents parameters for a kernel operation, storing dtype and shape information.
     Used to describe tensor or scalar parameters in TVM/PyTorch interop.
     """
+
     dtype: torch.dtype  # PyTorch data type of the parameter
     shape: list[int | Var]  # List of dimensions, can be integers or TVM variables

@@ -109,6 +111,7 @@ class CompiledArtifact:
     Represents a compiled kernel artifact containing both host and device code.
     Stores all necessary components for kernel execution in the TVM runtime.
     """
+
     host_mod: tvm.IRModule  # Host-side TVM IR module for managing kernel execution
     device_mod: tvm.IRModule  # Device-side TVM IR module containing the actual kernel code
     params: list[KernelParam]  # List of parameters (tensors/scalars) used by the kernel
tilelang/engine/phase.py

@@ -6,8 +6,7 @@ from tilelang.transform import PassContext
 from tilelang.contrib.nvcc import have_tma, is_hopper


-def allow_warp_specialized(pass_ctx: PassContext | None = None,
-                           target: Target | None = None) -> bool:
+def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     # avoid circular import
     from tilelang.jit.adapter.utils import is_cuda_target

@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None,
     return not disable_warp_specialized


-def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
-                                   target: Target | None = None) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
     if not have_tma(target):

@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) ->
     return enable_global_thread_sync


-def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
-                                   target: Target | None = None) -> bool:
+def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    enable_aggressive_merge = bool(
-        pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
+    enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
     if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
         # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
         # when warp specialization is enabled, as different warp threads may access different

@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
         return ["txt", "png", "pdf", "svg"]
     if "," in formats_str:
-        formats_list = [f.strip() for f in formats_str.split(',')]
+        formats_list = [f.strip() for f in formats_str.split(",")]
     else:
         formats_list = [formats_str]

@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # MergeSharedMemoryAllocations must be applied after SplitHostDevice
     # because the merged allocation site is at the beginning of each device function
    enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
-    mod = tilelang.transform.MergeSharedMemoryAllocations(
-        enable_aggressive_merge=enable_aggressive_merge)(mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
     # Inject PTX async copy must behind the thread sync pass
tilelang/env.py

@@ -10,36 +10,34 @@ from dataclasses import dataclass
 logger = logging.getLogger(__name__)

 # SETUP ENVIRONMENT VARIABLES
-CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
-                             ", which may lead to compilation bugs when utilize tilelang backend.")
-COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = ("Composable Kernel is not installed or found in the expected path"
-                                       ", which may lead to compilation bugs when utilize tilelang backend.")
-TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path"
-                                 ", which may lead to compilation bugs when utilize tilelang backend.")
-TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
+CUTLASS_NOT_FOUND_MESSAGE = (
+    "CUTLASS is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
+    "Composable Kernel is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+TL_TEMPLATE_NOT_FOUND_MESSAGE = (
+    "TileLang is not installed or found in the expected path"
+    ", which may lead to compilation bugs when utilize tilelang backend."
+)
+TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"

 TL_ROOT = os.path.dirname(os.path.abspath(__file__))
 # Only expose the internal lib directory to sys.path to avoid shadowing
 # common top-level module names (e.g., utils, analysis) from user projects.
-TL_LIBS = [os.path.join(TL_ROOT, 'lib')]
+TL_LIBS = [os.path.join(TL_ROOT, "lib")]
 TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)]

 DEV = False
-THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty')
+THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty")
 if not os.path.exists(THIRD_PARTY_ROOT):
     DEV = True
     tl_dev_root = os.path.dirname(TL_ROOT)
-    dev_lib_root = os.path.join(tl_dev_root, 'build')
+    dev_lib_root = os.path.join(tl_dev_root, "build")
     # In dev builds, place artifacts under build/lib and point search path there
     # to avoid adding the entire build root to sys.path.
-    TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')]
-    THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty')
-    logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}')
+    TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")]
+    THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty")
+    logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}")

-assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}'
+assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}"

 for lib in TL_LIBS:
     if lib not in sys.path:

@@ -52,7 +50,7 @@ def _find_cuda_home() -> str:
     Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
     """
     # Guess #1
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
     if cuda_home is None:
         # Guess #2
         nvcc_path = shutil.which("nvcc")

@@ -70,15 +68,15 @@ def _find_cuda_home() -> str:
         else:
             # Guess #3
-            if sys.platform == 'win32':
-                cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
-                cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0]
+            if sys.platform == "win32":
+                cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
+                cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0]
             else:
                 # Linux/macOS
-                if os.path.exists('/usr/local/cuda'):
-                    cuda_home = '/usr/local/cuda'
-                elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'):
-                    cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64'
+                if os.path.exists("/usr/local/cuda"):
+                    cuda_home = "/usr/local/cuda"
+                elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"):
+                    cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64"
     # Validate found path
     if cuda_home is None or not os.path.exists(cuda_home):

@@ -89,13 +87,13 @@ def _find_cuda_home() -> str:
 def _find_rocm_home() -> str:
     """Find the ROCM install path."""
-    rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME')
+    rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME")
     if rocm_home is None:
         rocmcc_path = shutil.which("hipcc")
         if rocmcc_path is not None:
             rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
         else:
-            rocm_home = '/opt/rocm'
+            rocm_home = "/opt/rocm"
             if not os.path.exists(rocm_home):
                 rocm_home = None
     return rocm_home if rocm_home is not None else ""

@@ -104,6 +102,7 @@ def _find_rocm_home() -> str:
 # Cache control
 class CacheState:
     """Class to manage global kernel caching state."""
+
     _enabled = True

     @classmethod

@@ -230,13 +229,11 @@ class Environment:
     TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))

     # Kernel Build options
-    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
-                                           "1")  # print kernel name on compile
-    TILELANG_DISABLE_CACHE = EnvVar(
-        "TILELANG_DISABLE_CACHE",
-        "0")  # disable kernel cache, usually for unit testing / debugging, high priority
-    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE",
-                                  "0")  # DEPRECATED! clear cache automatically if set
+    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1")  # print kernel name on compile
+    TILELANG_DISABLE_CACHE = EnvVar("TILELANG_DISABLE_CACHE", "0")  # disable kernel cache, usually for unit testing / debugging, high priority
+    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # DEPRECATED! clear cache automatically if set

@@ -244,12 +241,9 @@ class Environment:
     # Auto-tuning settings
     TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
-    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
-                                                "0.9")  # percent of CPUs used
-    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
-                                             "-1")  # -1 means auto
-    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
-                                                "-1")  # -1 means no limit
+    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9")  # percent of CPUs used
+    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1")  # -1 means auto
+    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1")  # -1 means no limit

     # TVM integration
     SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")

@@ -323,18 +317,18 @@ def prepend_pythonpath(path):
 if env.TVM_IMPORT_PYTHON_PATH is not None:
     prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
 else:
-    tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
+    tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python")
     assert os.path.exists(tvm_path), tvm_path
     if tvm_path not in sys.path:
         prepend_pythonpath(tvm_path)
     env.TVM_IMPORT_PYTHON_PATH = tvm_path

 # By default, the built TVM-related libraries are stored in TL_LIBS.
 if os.environ.get("TVM_LIBRARY_PATH") is None:
-    os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
+    os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)

 # Initialize CUTLASS paths
 if os.environ.get("TL_CUTLASS_PATH", None) is None:
-    cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include')
+    cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include")
     if os.path.exists(cutlass_inc_path):
         os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path
     else:

@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
 # Initialize COMPOSABLE_KERNEL paths
 if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
-    ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include')
+    ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include")
     if os.path.exists(ck_inc_path):
         os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path
     else:
tilelang/intrinsics/mfma_layout.py

@@ -4,7 +4,7 @@ import tilelang.language as T
 def shared_16x4_to_local_64x1_layout_A(i, j):
-    thread_id = (j * 16 + i)
+    thread_id = j * 16 + i
     return thread_id, convert(0)

@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
 def shared_4x16_to_local_64x1_layout_B(i, j):
-    thread_id = (i * 16 + j)
+    thread_id = i * 16 + j
     return thread_id, convert(0)

@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_C(i, j):
     thread_id = j + (i // 4) * 16
-    local = (i % 4)
+    local = i % 4
     return thread_id, local

@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_A(i, j):
     thread_id = i + 16 * (j // 4)
-    local = (j % 4)
+    local = j % 4
     return thread_id, local

@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
 def shared_16x16_to_local_64x4_layout_B(i, j):
     thread_id = j + (i // 4) * 16
-    local = (i % 4)
+    local = i % 4
     return thread_id, local

@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
 def shared_16x32_to_local_64x8_layout_A(i, j):
     thread_id = i + 16 * (j // 8)
-    local = (j % 8)
+    local = j % 8
     return thread_id, local

@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
 def shared_16x32_to_local_64x8_layout_B(i, j):
     thread_id = j + (i // 8) * 16
-    local = (i % 8)
+    local = i % 8
     return thread_id, local

@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
 def shared_16x64_to_local_64x16_layout_A(i, j):
     thread_id = i + 16 * (j // 16)
-    local = (j % 16)
+    local = j % 16
     return thread_id, local

@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
 def shared_16x64_to_local_64x16_layout_B(i, j):
     thread_id = i + 16 * (j // 16)
-    local = (j % 16)
+    local = j % 16
     return thread_id, local
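Each shared_*_to_local_* function above pairs with a thread_id_*_layout inverse whose body this diff only references in hunk headers. A quick standalone round-trip check in plain Python (the inverse below is a hypothetical reimplementation consistent with the forward map, not code from the repository):

def shared_16x16_to_local_64x4_layout_A(i, j):
    # 64 lanes, each owning 4 elements of a 16x16 tile
    thread_id = i + 16 * (j // 4)
    local = j % 4
    return thread_id, local

def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
    # hypothetical inverse: recover tile coordinates from lane/local ids
    i = thread_id % 16
    j = (thread_id // 16) * 4 + local_id
    return i, j

for i in range(16):
    for j in range(16):
        tid, local = shared_16x16_to_local_64x4_layout_A(i, j)
        assert thread_id_shared_access_64x4_to_16x16_layout_A(tid, local) == (i, j)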
tilelang/intrinsics/mfma_macro_generator.py

@@ -6,7 +6,7 @@ from tvm import tir
 from tvm.ir import Range
 from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
 from tvm.runtime import convert
-from .utils import (mfma_store_index_map)
+from .utils import mfma_store_index_map
 from typing import Literal, Callable
 from tilelang.utils import is_fragment

@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
         self.num_elems_per_byte = num_elems_per_byte
         self.thread_var = thread_var

@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
     def _initialize_mfma_prefix(self, k_dim=16):
         in_dtype, out_dtype = self.a_dtype, self.accum_dtype
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
-        out_dtype_abbrv = {
-            "float16": "f16",
-            "float32": "f32",
-            "int8": "i8",
-            "int32": "i32"
-        }[out_dtype]
+        out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
         in_dtype_abbrv = {
             "bfloat16": "bf16",

@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
         self.b_preshuffle = b_preshuffle
-
     def get_ldmatrix_index_map(self, is_b=False):
         k_dim = self.k_dim * self.k_pack
         transposed = self.a_transposed if not is_b else self.b_transposed
         if k_dim == 4:

@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
             reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
             if is_b:
                 index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
-                reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                )
         elif k_dim == 16:
             index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
-            reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            )
             if is_b:
                 index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
-                reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                )
         elif k_dim == 32:
             index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
-            reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            )
             if is_b:
                 index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
-                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                )
         elif k_dim == 64:
             index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
-            reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            )
             if is_b:
                 index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
-                reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                )
         else:
             raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")

@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
         else:
             return self.thread_var

-    def extract_thread_binding(self,
-                               thread_id,
-                               is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
-        '''
+    def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
        Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
-        '''
+        """
         WARP_SIZE = self.WARP_SIZE
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps

@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
             is_m_first = self.is_m_first
         if is_m_first:
-            lane_id, warp_n, warp_m = thread_id % WARP_SIZE, (
-                thread_id // WARP_SIZE) % block_col_warps, (
-                    thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
+            lane_id, warp_n, warp_m = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_col_warps,
+                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
+            )
             return lane_id, warp_n, warp_m
         else:
-            lane_id, warp_m, warp_n = thread_id % WARP_SIZE, (
-                thread_id // WARP_SIZE) % block_row_warps, (
-                    thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
+            lane_id, warp_m, warp_n = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_row_warps,
+                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
+            )
             return lane_id, warp_n, warp_m
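The branch above is a mixed-radix decomposition of the flat thread id into [warp_size, block_col_warps, block_row_warps] digits. A standalone sketch of the is_m_first case in plain Python (the launch-shape constants are invented for the demo; WARP_SIZE=64 matches the AMD wavefront this matrix-core emitter assumes):

WARP_SIZE = 64        # AMD wavefront size used by MatrixCoreIntrinEmitter
block_row_warps = 2   # hypothetical launch shape for the demo
block_col_warps = 2

def extract_thread_binding_m_first(thread_id):
    # digits of thread_id in the mixed radix [WARP_SIZE, block_col_warps, block_row_warps]
    lane_id = thread_id % WARP_SIZE
    warp_n = (thread_id // WARP_SIZE) % block_col_warps
    warp_m = (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps
    return lane_id, warp_n, warp_m

# thread 200 = 1*128 + 1*64 + 8  ->  lane 8, warp_n 1, warp_m 1
assert extract_thread_binding_m_first(200) == (8, 1, 1)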
     def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):

@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
             for i in T.serial(warp_rows):
                 for local_id in T.vectorized(k_pack * local_size_a):
                     row, col = T.meta_var(reverse_index_map(tx, local_id))
-                    l, r = (rk * chunk + ki * (k_pack * micro_size_k),
-                            warp_m * warp_row_tiles + i * micro_size_x)
-                    A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
-                                                                              A_base1 + r + col]
+                    l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
+                    A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (warp_m * warp_row_tiles + i * micro_size_x,
-                                rk * chunk + ki * (k_pack * micro_size_k))
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
-                                                                                  A_base1 + r + col]
+                        l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]

         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)

@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
                         warp_n * warp_col_tiles + j * micro_size_y,
                         rk * chunk + ki * (k_pack * micro_size_k),
                     )
-                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
-                                                                              B_base1 + r + col]
+                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
             else:
                 for j in T.serial(warp_cols):

@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
                         rk * chunk + ki * (k_pack * micro_size_k),
                         warp_n * warp_col_tiles + j * micro_size_y,
                     )
-                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
-                                                                              B_base1 + r + col]
+                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]

         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

-    def mfma(self,
-             A_local_buf: Buffer,
-             B_local_buf: Buffer,
-             C_local_buf: Buffer,
-             k_inner: PrimExpr | None = 0):
+    def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_a = self.local_size_a

@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                     if C_buf_dims == 2:
-                        C_buf[(warp_m * warp_rows + i) * M_DIM + row,
-                              (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
-                                  i * (warp_cols * local_size_out) + j * local_size_out + local_id]
+                        C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = (
+                            C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
+                        )
                     else:
-                        C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
-                              col] = C_local_buf[i * warp_cols * local_size_out +
-                                                 j * local_size_out + local_id]
+                        C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
+                            i * warp_cols * local_size_out + j * local_size_out + local_id
+                        ]

         @T.macro
         def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):

@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter:
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
-                    C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
-                          (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
-                              i * warp_cols * local_size_out + j * local_size_out + local_id]
+                    C_buf[
+                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
+                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col,
+                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]

-        return _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+        return (
+            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+            if is_global
+            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+        )

     def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
         """
         Create a layout function for storing MFMA results into a fragment buffer.

@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter:
             If `local_buf` is not detected to be a fragment buffer.
         """
         from tilelang.utils import is_fragment
+
         assert matrix in ["A", "B"], "matrix should be either A or B"
         matrix_is_a: bool = matrix == "A"
         matrix_is_b: bool = matrix == "B"

@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter:
         transform_func: Callable = None
         if matrix_is_a:
-            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
-                j, i)
+            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
         elif matrix_is_b:
-            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
-                j, i)
+            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
         else:
             raise ValueError(f"Unsupported matrix {matrix}")

@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter:
             return local_id

         base_fragment = T.Fragment(
-            [micro_size_s, micro_size_r * self.k_pack]
-            if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
+            [micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
             forward_thread_fn=forward_thread,
             forward_index_fn=forward_index,
         )

@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter:
         replicate = block_col_warps if matrix_is_a else block_row_warps
         if is_sr_axis_order:
-            warp_fragment = base_fragment.repeat([warp_s, warp_r],
-                                                 repeat_on_thread=False,
-                                                 lower_dim_first=False)
+            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
             if matrix_is_a:
-                block_fragment = warp_fragment.repeat([block_s, 1],
-                                                      repeat_on_thread=True,
-                                                      lower_dim_first=True).replicate(replicate)
+                block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
             elif matrix_is_b:
-                block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
-                                                                           repeat_on_thread=True,
-                                                                           lower_dim_first=True)
+                block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
             else:
                 raise ValueError(f"Unsupported matrix type {matrix}")
         else:
-            warp_fragment = base_fragment.repeat([warp_r, warp_s],
-                                                 repeat_on_thread=False,
-                                                 lower_dim_first=True)
+            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
             if matrix_is_a:
-                block_fragment = warp_fragment.repeat([1, block_s],
-                                                      repeat_on_thread=True,
-                                                      lower_dim_first=True).replicate(replicate)
+                block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
             elif matrix_is_b:
-                block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
-                                                                           repeat_on_thread=True,
-                                                                           lower_dim_first=True)
+                block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
             else:
                 raise ValueError(f"Unsupported matrix type {matrix}")

@@ -686,7 +668,6 @@
 class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
-
     def __init__(
         self,
         a_dtype: str = "float16",

@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                         rk * (chunk // micro_size_k) + ki,
                         warp_m * warp_rows + i,
                     )
-                    A_local_buf[i * k_pack * local_size_a +
-                                local_id] = A_shared_buf[l, r, row, col]
+                    A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
             else:
                 print(self.a_preshuffle)
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
-                        A_local_buf[i * k_pack * local_size_a +
-                                    local_id] = A_shared_buf[l, r, row, col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]

-        return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk) if is_global else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
+            if is_global
+            else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
+        )

     def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
         warp_cols = self.warp_cols

@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                         warp_n * warp_cols + j,
                         rk * (chunk // micro_size_k) + ki,
                     )
-                    B_local_buf[j * k_pack * local_size_b +
-                                local_id] = B_shared_buf[l, r, row, col]
+                    B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
             else:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):

@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                         rk * (chunk // micro_size_k) + ki,
                         warp_n * warp_cols + j,
                     )
-                    B_local_buf[j * k_pack * local_size_b +
-                                local_id] = B_shared_buf[l, r, row, col]
+                    B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]

-        return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk) if is_global else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
+            if is_global
+            else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
+        )
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_macro_generator.py
View file @
29051439
...
...
@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter:
def
get_store_index_map
(
self
,
inverse
:
bool
=
False
)
->
IndexMap
:
from
.utils
import
mma_store_index_map
,
mma_store_index_map_fp64
warp_size
,
local_size_c
=
self
.
WARP_SIZE
,
self
.
local_size_out
if
DataType
(
self
.
accum_dtype
).
bits
==
64
:
index_map
=
IndexMap
.
from_func
(
mma_store_index_map_fp64
,
index_dtype
=
"int32"
)
...
...
@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter:
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
return
inverse_index_map
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter:
)
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if
DataType
(
self
.
a_dtype
).
bits
==
64
:
warp_row_tiles
=
self
.
warp_row_tiles
...
...
@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter:
for
i
in
T
.
serial
(
warp_rows
):
# Assign A_shared_buf_elem
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
micro_size_k
A_shared_buf_elem
=
A_buf
[
A_base0
+
wk
,
A_base1
+
wi
]
if
a_transposed
else
A_buf
[
A_base0
+
wi
,
A_base1
+
wk
]
A_shared_buf_elem
=
A_buf
[
A_base0
+
wk
,
A_base1
+
wi
]
if
a_transposed
else
A_buf
[
A_base0
+
wi
,
A_base1
+
wk
]
if
ldmatrix_available
:
T
.
ptx_ldmatrix
(
...
...
@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter:
for
j
in
T
.
serial
(
local_size_a
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
if
a_transposed
:
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wk
+
mk
,
A_base1
+
wi
+
mi
]
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wk
+
mk
,
A_base1
+
wi
+
mi
]
else
:
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wi
+
mi
,
A_base1
+
wk
+
mk
]
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wi
+
mi
,
A_base1
+
wk
+
mk
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_region
,
ki
,
thread_binding
,
rk
)
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if
DataType
(
self
.
b_dtype
).
bits
==
64
:
warp_col_tiles
=
self
.
warp_col_tiles
...
...
@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter:
B_base0
=
B_region
.
region
[
-
2
].
min
B_base1
=
B_region
.
region
[
-
1
].
min
B_stride_last
=
B_buf
.
shape
[
-
1
]
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available
=
not
(
DataType
(
b_dtype
).
bits
!=
16
and
not
b_transposed
)
...
...
@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter:
)
if
ldmatrix_available
:
B_shared_buf_elem
=
B_buf
[
B_base0
+
wi
,
B_base1
+
wk
]
if
b_transposed
else
B_buf
[
B_base0
+
wk
,
B_base1
+
wi
]
B_shared_buf_elem
=
B_buf
[
B_base0
+
wi
,
B_base1
+
wk
]
if
b_transposed
else
B_buf
[
B_base0
+
wk
,
B_base1
+
wi
]
T
.
ptx_ldmatrix
(
b_dtype
,
...
...
@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter:
for
j
in
T
.
serial
(
local_size_b
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
if
b_transposed
:
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
else
:
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter:
accum_dtype
=
self
.
accum_dtype
accum_dtype_abbrv
=
self
.
accum_dtype_abbrv
mma_prefix
=
self
.
mma_prefix
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
a_is_fragment
=
is_fragment
(
A_local_buf
)
b_is_fragment
=
is_fragment
(
B_local_buf
)
...
...
@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter:
B_local_buf
.
data
,
b_local_stride
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
T
.
bool
(
False
),
# saturate
)
...
...
@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter:
local_id
=
local_id_o
*
2
+
local_id_i
row
,
col
=
T
.
meta_var
(
mma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
else
:
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
...
@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter:
C_buf
[
(
pid_m
*
BLOCK_M
+
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
pid_n
*
BLOCK_N
+
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
,
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
))
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
)
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
matrix_is_a
:
bool
=
matrix
==
"A"
matrix_is_b
:
bool
=
matrix
==
"B"
...
...
@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix_is_a
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
elif
matrix_is_b
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter:
replicate
=
block_col_warps
if
matrix_is_a
else
block_row_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
else
:
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
...
...
@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter:
from
tilelang.utils
import
is_fragment
shape
=
local_buf
.
shape
assert
is_fragment
(
local_buf
),
f
"local_buf
{
local_buf
}
must be a fragment, but got
{
local_buf
.
scope
()
}
"
assert
is_fragment
(
local_buf
),
f
"local_buf
{
local_buf
}
must be a fragment, but got
{
local_buf
.
scope
()
}
"
inverse_mma_store_layout
=
self
.
get_store_index_map
(
inverse
=
True
)
micro_size_x
,
micro_size_y
=
self
.
micro_size_x
,
self
.
micro_size_y
...
...
@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
".b16"
,
A_local_buf
.
data
,
i
*
local_size_a
,
T
.
address_of
(
A_shared_buf
[
T
.
address_of
(
A_shared_buf
[
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
micro_size_k
,
]),
]
),
get_ldmatrix_offset
(
"A"
,
tx
,
0
,
stride
,
a_dtype
,
a_transposed
),
)
elif
transform_kind_a
==
TransformKind
.
InterWarpTransform
:
...
...
@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
                    warp_m * warp_rows + j,
                    rk * (chunk // micro_size_k) + ki,
                )
-               rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k)
-               A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj])
+               rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k)
+               A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj]
        else:
            raise ValueError("Unsupported TransformKind for Input A")
...
...
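The `rii`/`rjj` computation above splits a flat per-thread element index back into micro-tile coordinates via div/mod. A standalone check of the decomposition (tile sizes assumed, not taken from the commit):

# Assumed example sizes for illustration.
micro_size_k = 16   # micro tile extent along K
local_size_a = 8    # elements held per thread

for tx in range(4):
    for local_id in range(local_size_a):
        flat = tx * local_size_a + local_id
        rii, rjj = flat // micro_size_k, flat % micro_size_k
        assert flat == rii * micro_size_k + rjj  # div/mod round-trips exactly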
@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
                    warp_n * warp_cols + j,
                    rk * (chunk // micro_size_k) + ki,
                )
-               rii, rjj = (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte), (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte)
-               B_local_buf[j * local_size_dequantize + local_id] = (B_shared_buf[ri, rj, rii, rjj])
+               rii, rjj = (
+                   (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte),
+                   (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte),
+               )
+               B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj]
        else:
            raise ValueError("Unsupported TransformKind for Input B")
...
...
@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
...
...
@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
tilelang/intrinsics/mma_sm70_layout.py View file @ 29051439
...
...
@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
-   row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
-   col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
+   row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
+   col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
    return row, col
...
...
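The fp32 store map above sends each (thread_id, local_id) pair on the 32x8 grid to a cell of the 16x16 tile. A standalone way to sanity-check that the map is a bijection (which its later inversion relies on) is to enumerate it:

# Standalone check, not part of the commit: the 32x8 -> 16x16 map hits every cell once.
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
    row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
    col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
    return row, col

cells = {mma_32x8_to_shared_16x16_layout_fp32(t, l) for t in range(32) for l in range(8)}
assert len(cells) == 256  # all 16 * 16 (row, col) cells produced exactly once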
@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
-   row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
+   row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4))
    col = local_id
    return row, col
...
...
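Since changes like the one above only drop a redundant pair of parentheses, a quick standalone check (not part of the commit) can confirm the row mapping is unchanged for all 32 lanes:

def row_old(thread_id):
    return (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))

def row_new(thread_id):
    return (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4))

assert all(row_old(t) == row_new(t) for t in range(32))  # formatting-only change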
tilelang/intrinsics/mma_sm70_macro_generator.py View file @ 29051439
...
...
@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter:
    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
        index_map = IndexMap.from_func(
-           mma_32x8_to_shared_16x16_layout_fp32
-           if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
-           index_dtype="int32")
+           mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
+           index_dtype="int32",
+       )
        if not inverse:
            return index_map
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

-   def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+   def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
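The docstring above only sketches the two orderings. As a standalone illustration, one plausible reading of the decomposition (warp-grid sizes assumed; this is not the emitter's exact formula):

# Hypothetical decomposition of a flat thread id into (lane, warp_n, warp_m).
WARP_SIZE = 32
block_row_warps, block_col_warps = 2, 2  # assumed warp grid

def extract_thread_binding(thread_id, is_m_first=False):
    lane_id = thread_id % WARP_SIZE
    if is_m_first:
        warp_m = (thread_id // WARP_SIZE) % block_col_warps
        warp_n = thread_id // (WARP_SIZE * block_col_warps)
    else:
        warp_n = (thread_id // WARP_SIZE) % block_row_warps
        warp_m = thread_id // (WARP_SIZE * block_row_warps)
    return lane_id, warp_n, warp_m

print(extract_thread_binding(97))  # -> (1, 1, 1) with the assumed 2x2 warp grid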
@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter:
            )
        return lane_id, warp_n, warp_m

-   def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
+   def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
        warp_row_tiles = self.warp_row_tiles
        warp_rows = self.warp_rows
        chunk = self.chunk
...
...
@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter:
        return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

-   def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
+   def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
        warp_col_tiles = self.warp_col_tiles
        warp_cols = self.warp_cols
        chunk = self.chunk
...
...
@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter:
                for j in T.vectorized(local_size_b):
                    if b_transposed:
                        mi, mk = mma_load_layout(tx, j)
-                       B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
+                       B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
                    else:
                        mk, mi = mma_load_layout(tx, j)
-                       B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
+                       B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

        return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)

-   def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
+   def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
...
@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter:
        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

-   def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
+   def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
+       from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
...
@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter:
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        if matrix_is_a:
-           transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
+           transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
-           transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
+           transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")
...
...
@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter:
            return lane_id, local_id

        base_fragment = T.Fragment(
-           [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2)
+           [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2)

        warp_rows, warp_cols = self.warp_rows, self.warp_cols
        chunk = self.chunk
...
...
@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
-           warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
+           warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
-               block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+               block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
-               block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
+               block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
-           warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
+           warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
-               block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+               block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
-               block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
+               block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
...
...
tilelang/intrinsics/mma_sp_layout.py View file @ 29051439
...
...
@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int:
    return (thread_id // 4) * 2 + (thread_id % 4) % 2


-def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
+def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_32bit(thread_id)
    row = logical_id // 4 + local_id * 8
    col = logical_id % 4
    return row, col


-def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
+def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_32bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = logical_id % 2
    return row, col


-def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
-    return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit
+def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
+    return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit


-def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
-    return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit
+def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
+    return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id)  # same mapping for 16bit and 32bit


def get_logical_id_8bit(thread_id: int) -> int:
    return thread_id


-def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
+def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = (logical_id % 4) // 2 * 4 + local_id
    return row, col


-def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
+def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 2 + local_id * 8
    col = (logical_id % 4) // 2 * 2 + local_id
    return row, col


-def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
+def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
    # local_id is always 0
    logical_id = get_logical_id_8bit(thread_id)
    row = logical_id // 4 + (logical_id % 2) * 8
...
...
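To see which shared-memory cells these metadata maps touch, the functions above can be enumerated offline. A small standalone sketch (the `local_id` range of 0..3 is an assumption read off the `32x4` in the name):

from collections import Counter

def get_logical_id_32bit(thread_id):
    return (thread_id // 4) * 2 + (thread_id % 4) % 2

def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id):
    logical_id = get_logical_id_32bit(thread_id)
    return logical_id // 4 + local_id * 8, logical_id % 4

hits = Counter(
    metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(t, l)
    for t in range(32)
    for l in range(4)  # assumed local range
)
print(sorted(hits.items()))  # each visited (row, col) cell and how often it is hit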
tilelang/intrinsics/mma_sp_macro_generator.py View file @ 29051439
...
...
@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter:
    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
        self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
-       self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
+       self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
        self.local_size_b = (n_dim * k_dim) // warp_size
        self.local_size_out = (m_dim * n_dim) // warp_size
...
...
@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter:
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

-   def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+   def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter:
            for i in T.serial(warp_rows):
                # Assign A_shared_buf_elem
-               wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
+               wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
                A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]

                if ldmatrix_available:
...
...
@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter:
                else:
                    for j in T.serial(local_size_a):
                        mi, mk = mma_load_layout(tx, j)
-                       A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
+                       A_local_buf[i * local_size_a + j] = (
+                           A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
+                       )

        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
...
...
@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter:
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            for i in T.serial(warp_rows):
                # Assign E_shared_buf_elem
-               wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor
+               wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor
                for j in T.serial(local_size_e):
                    mi, mk = mma_load_layout(tx, j)
-                   E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]
+                   E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]

        return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
...
...
@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter:
        b_dtype = self.b_dtype
        b_transposed = self.b_transposed
        thread_binding = self.get_thread_binding()
-       replicate_b = (self.n_dim == 16)
+       replicate_b = self.n_dim == 16
        # ldmatrix cannot be used for int8 + trans case.
        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...
...
@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter:
                )
                if ldmatrix_available:
-                   B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
+                   B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
                    if replicate_b:
                        T.ptx_ldmatrix(
...
...
@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter:
                            B_local_buf.data,
                            i * local_size_b + lift(local_size_b) // 2,
                            T.address_of(B_shared_buf_elem),
-                           get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed),
+                           get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed),
                        )
                    else:
                        T.ptx_ldmatrix(
...
...
@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter:
                    # must be transposed.
                    for j in T.serial(local_size_b):
                        mi, mk = mma_load_layout(tx, j)
-                       B_local_buf[i * local_size_b + j] = B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi]
+                       B_local_buf[i * local_size_b + j] = (
+                           B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi]
+                       )

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

-   def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0):
+   def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
...
...
@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter:
        accum_dtype = self.accum_dtype
        accum_dtype_abbrv = self.accum_dtype_abbrv
        mma_prefix = self.mma_prefix
-       replicate_b = (self.n_dim == 16)
+       replicate_b = self.n_dim == 16
        a_is_fragment = is_fragment(A_local_buf)
        e_is_fragment = is_fragment(E_local_buf)
...
...
@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter:
                    B_local_buf.data,
                    b_local_stride + j * local_size_b + lift(local_size_b) // 2,
                    C_local_buf.data,
-                   i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
+                   i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    E_local_buf.data,  # metadata
                    e_local_stride + i * local_size_e,  # metadata offset
                    self.SPARSE_SELECTOR,  # sparse_selector
...
...
@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter:
                    local_id = local_id_o * 2 + local_id_i
                    row, col = T.meta_var(mma_store_index_map(tx, local_id))
                    if C_buf_dims == 2:
-                       C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
+                       C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
                    else:
-                       C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]
+                       C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[i * (warp_cols * local_size_out) + j * local_size_out + local_id]

        @T.macro
        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...
...
@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter:
                    C_buf[
                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
-                   ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
+                   ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]

-       return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
-               if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
+       return (
+           _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+           if is_global
+           else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+       )

-   def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
+   def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter:
            If `local_buf` is not detected to be a fragment buffer.
        """
+       from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
...
...
@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        if matrix_is_a:
-           transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
+           transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
-           transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
+           transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")
...
...
@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter:
            return local_id

        base_fragment = T.Fragment(
-           [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order
+           [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r]
+           if is_sr_axis_order
            else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
            forward_thread_fn=forward_thread,
            forward_index_fn=forward_index,
...
...
@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter:
        replicate = block_col_warps if matrix_is_a else block_row_warps
        if is_sr_axis_order:
-           warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
+           warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
-               block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+               block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
-               block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
+               block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
-           warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
+           warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
-               block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+               block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
-               block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
+               block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
...
...
tilelang/intrinsics/tcgen05_macro_generator.py View file @ 29051439
...
...
@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        is_m_first: bool = False,
        thread_var: Var | None = None,
    ):
-       super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, num_elems_per_byte, is_m_first, thread_var)
+       super().__init__(
+           a_dtype,
+           b_dtype,
+           accum_dtype,
+           a_transposed,
+           b_transposed,
+           block_row_warps,
+           block_col_warps,
+           warp_row_tiles,
+           warp_col_tiles,
+           chunk,
+           reduce_k,
+           num_elems_per_byte,
+           is_m_first,
+           thread_var,
+       )

    def _assign_a_shared_layout(self, layout: Layout):
        self.a_shared_layout = layout
...
...
@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            raise ValueError(f"Unsupported swizzle mode: {layout}")

-   def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False):
+   def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False):
        if is_tensor_memory(A_buf):
            return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
...
...
@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        elems_in_bits = DataType(self.a_dtype).bits
        elems_in_bytes = elems_in_bits // 8
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        accum_dtype_in_bits = DataType(accum_dtype).bits
        meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
        if len(meta) != 5:
            raise ValueError(
                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
-               f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
+               f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
+           )
        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
        # by default, we utilize non-swizzle layout offset
-       a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
-       a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
+       a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
+       a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
        if not a_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
...
...
@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
-       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
+       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
+       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
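The leading/stride byte offsets above feed the shared-memory descriptors, which encode them in 16-byte units (hence the later `>> 4`). A standalone illustration of the non-swizzled arithmetic, with assumed sizes:

# Assumed example sizes, not taken from the commit.
elems_in_bytes = 2        # e.g. a 16-bit A dtype
m_dim, k_dim = 128, 64
a_is_k_major = True

a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)

# Descriptor fields are expressed in 16-byte units.
print(a_leading_byte_offset >> 4, a_stride_byte_offset >> 4)  # -> 8 64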
@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                for ki in T.unroll(0, (k_dim // micro_size_k)):
                    scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
-                   A_elem_offset = (ki % ak_atom_size) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                   A_elem_offset = (
+                       (ki % ak_atom_size) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
+                       if a_is_k_major
+                       else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                   )
-                   B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else (ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   B_elem_offset = (
+                       (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k + j * atom_n * b_swizzle_atom_elems
+                       if b_is_k_major
+                       else (ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   )
                    A_byte_offset = A_elem_offset * elems_in_bytes
                    B_byte_offset = B_elem_offset * elems_in_bytes
-                   C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
+                   C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
                    T.ptx_tcgen05_mma_ss(
                        a_dtype_abbrv,
...
...
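The `C_offset` line above addresses the accumulator in 32-bit tensor-memory banks; a standalone sketch of the arithmetic (tile geometry assumed, not taken from the commit):

# Assumed example geometry for illustration.
accum_dtype_in_bits = 32      # fp32 accumulator
n_dim, tmem_col_step = 64, 16

for i in range(2):
    for j in range(2):
        C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32
        print(i, j, C_offset)  # offset of each (i, j) atom in 32-bit tmem banks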
@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
assert
is_tensor_memory
(
tmem_buf
),
"tmem_buf must reside in tensor memory (shared.tmem)"
if
len
(
tmem_buf
.
shape
)
!=
2
:
raise
ValueError
(
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
raise
ValueError
(
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
m
=
int
(
tmem_buf
.
shape
[
0
])
n
=
int
(
tmem_buf
.
shape
[
1
])
...
...
@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        meta = self.get_tcgen5_mma_meta(m, n, k)
        if len(meta) != 5:
-           raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
-                            f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
+           raise ValueError(
+               f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
+           )
        atom_m, atom_n, _, _, _ = (int(x) for x in meta)
        if m % atom_m != 0 or n % atom_n != 0:
-           raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})")
+           raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})")

        def forward(i: PrimExpr, j: PrimExpr):
            atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
...
...
@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        return Layout([m, n], forward)

    def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
-       return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
+       return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))

-   def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
+   def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
        desc = _ffi_api.get_tcgen5_instr_desc(
            atom_m,
            atom_n,
...
...
tilelang/intrinsics/utils.py View file @ 29051439
...
...
@@ -10,7 +10,7 @@ from .mma_layout import (
    mma_store_32x8_to_shared_16x16_layout,
    mma_store_32x2_to_shared_8x8_layout_fp64,
)
-from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
+from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m
from .mma_layout import get_swizzle_layout  # noqa: F401
from .mma_layout import make_mma_swizzle_layout  # noqa: F401
...
...
tilelang/intrinsics/wgmma_macro_generator.py View file @ 29051439
...
...
@@ -15,9 +15,11 @@ from tilelang.layout import (
    make_linear_layout,
)
from tvm.runtime import convert
-from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a,
-                                            shared_16x16_to_mma_32x8_layout_sr_a,
-                                            shared_16x32_to_mma_32x16_layout_sr_a)
+from tilelang.intrinsics.mma_layout import (
+    shared_16x8_to_mma_32x4_layout_sr_a,
+    shared_16x16_to_mma_32x8_layout_sr_a,
+    shared_16x32_to_mma_32x16_layout_sr_a,
+)

lift = convert
...
...
@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        is_m_first: bool | None = False,
        thread_var: Var | None = None,
    ):
-       super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, num_elems_per_byte, is_m_first, thread_var)
+       super().__init__(
+           a_dtype,
+           b_dtype,
+           accum_dtype,
+           a_transposed,
+           b_transposed,
+           block_row_warps,
+           block_col_warps,
+           warp_row_tiles,
+           warp_col_tiles,
+           chunk,
+           reduce_k,
+           num_elems_per_byte,
+           is_m_first,
+           thread_var,
+       )
        self._initialize_wgmma_prefix(self.n_dim)

    def _assign_a_shared_layout(self, layout: Layout):
...
...
@@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
    def _initialize_wgmma_prefix(self, n_dim: int = 16):
        inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
        assert inst_n % 8 == 0, (
-           f"inst_n must be a multiple of 8, got {inst_n} "
-           f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
+           f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
+       )
        # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
        assert 8 <= inst_n <= 256, (
-           f"inst_n must be within [8, 256], got {inst_n} "
-           f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
+           f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
+       )
        # 256 bits per instruction
        inst_k = 256 // DataType(self.a_dtype).bits
        self.wgmma_inst_m = inst_m
...
...
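The `gcd(warp_col_tiles, 256)` above picks the widest WGMMA N that both divides the tile and stays within the hardware limit of 256. A standalone illustration of the arithmetic, with an arbitrary example width:

from math import gcd

warp_col_tiles = 192  # arbitrary example width
inst_m, inst_n = 64, gcd(warp_col_tiles, 256)

assert inst_n % 8 == 0
assert 8 <= inst_n <= 256
print(inst_m, inst_n)  # -> 64 64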
@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        else:
            raise ValueError(f"Unsupported swizzle mode: {layout}")

-   def wgmma(self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0):
+   def wgmma(self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0):
        if is_fragment(A_region):
            return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
...
...
@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        elems_in_bytes = elems_in_bits // 8
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        accum_bits = DataType(accum_dtype).bits
        accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
        # by default, we utilize non-swizzle layout offset
-       a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
-       a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
+       a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
+       a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
        if not a_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
...
...
@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            if a_m_axis_atoms <= 1:
                a_leading_byte_offset = 0
            else:
-               a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+               a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
            if a_m_axis_atoms <= 1:
                a_stride_byte_offset = 8 * elems_in_bytes * m_dim
            else:
                a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
-       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
+       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
+       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            desc_a = T.alloc_wgmma_desc()
            desc_b = T.alloc_wgmma_desc()
-           T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+           T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
+           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()
...
...
@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                    warp_i = (warp_m // 4) * num_inst_m + i
                    warp_j = warp_n * num_inst_n + j
-                   A_offset = (ki % ak_atom_size) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                   A_offset = (
+                       (ki % ak_atom_size) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
+                       if a_is_k_major
+                       else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                   )
-                   B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   B_offset = (
+                       (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems
+                       if b_is_k_major
+                       else (ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   )
                    C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
-                   T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4, desc_b.data, (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset, scale_out, scale_in_a, scale_in_b)
+                   T.ptx_wgmma_ss(
+                       accum_dtype,
+                       wgmma_prefix,
+                       a_is_k_major,
+                       b_is_k_major,
+                       a_dtype_abbrv,
+                       b_dtype_abbrv,
+                       accum_dtype_abbrv,
+                       desc_a.data,
+                       (A_offset * elems_in_bytes) >> 4,
+                       desc_b.data,
+                       (B_offset * elems_in_bytes) >> 4,
+                       C_buf.data,
+                       C_offset,
+                       scale_out,
+                       scale_in_a,
+                       scale_in_b,
+                   )
                T.warpgroup_commit_batch()
                if wg_wait >= 0:
...
...
@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            return _warp_mma(A_ptr, B_ptr, C_buf)

-   def wgmma_rs(self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0):
+   def wgmma_rs(self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0):
        local_size_a = self.local_size_a
        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
...
...
@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        b_is_k_major = self.b_transposed
        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
-       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
+       b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+       b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
+       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            desc_b = T.alloc_wgmma_desc()
-           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
            T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()
...
...
@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                    A_offset = ki * warp_rows * local_size_a + i * local_size_a
-                   B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k if b_is_k_major else (ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   B_offset = (
+                       (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k
+                       if b_is_k_major
+                       else (ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                   )
                    C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                    T.ptx_wgmma_rs(
                        accum_dtype,
...
...
@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            If `local_buf` is not detected to be a fragment buffer.
        """
+       from tilelang.utils import is_fragment

        assert matrix in ["A"], "matrix should be A for WGMMA"
        dtype = self.a_dtype
        dtype_bits = DataType(dtype).bits
...
...
@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        # the layout of mma.sync is row.col.
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
-       transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
+       transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
...
...
@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        replicate = block_col_warps
        if is_sr_axis_order:
-           warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
-           block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
+           warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
+           block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
        else:
            # rs condition, transposed_a matrix
-           warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
-           block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
+           warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
+           block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
        return block_fragment
...
...
tilelang/ir.py View file @ 29051439
...
...
@@ -7,23 +7,19 @@ from tilelang import _ffi_api
@tvm_ffi.register_object("tl.Fill")
-class Fill(Node, Scriptable):
-    ...
+class Fill(Node, Scriptable): ...


@tvm_ffi.register_object("tl.AtomicAdd")
-class AtomicAdd(Node, Scriptable):
-    ...
+class AtomicAdd(Node, Scriptable): ...


@tvm_ffi.register_object("tl.Copy")
-class Copy(Node, Scriptable):
-    ...
+class Copy(Node, Scriptable): ...


@tvm_ffi.register_object("tl.Conv2DIm2Col")
-class Conv2DIm2ColOp(Node, Scriptable):
-    ...
+class Conv2DIm2ColOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.GemmWarpPolicy")
...
...
@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable):
    m_warp: int
    n_warp: int

-   def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool):
-       _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma)
+   def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool):
+       _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma)
        return self.m_warp, self.n_warp
...
...
@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable):
    m_warp: int
    n_warp: int

-   def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int):
-       _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits)
+   def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int):
+       _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits)
        return self.m_warp, self.n_warp


@tvm_ffi.register_object("tl.Gemm")
-class Gemm(Node, Scriptable):
-    ...
+class Gemm(Node, Scriptable): ...


@tvm_ffi.register_object("tl.GemmSP")
-class GemmSP(Node, Scriptable):
-    ...
+class GemmSP(Node, Scriptable): ...


@tvm_ffi.register_object("tl.FinalizeReducerOp")
-class FinalizeReducerOp(Node, Scriptable):
-    ...
+class FinalizeReducerOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.ParallelOp")
-class ParallelOp(Node, Scriptable):
-    ...
+class ParallelOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.ReduceOp")
-class ReduceOp(Node, Scriptable):
-    ...
+class ReduceOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.CumSumOp")
-class CumSumOp(Node, Scriptable):
-    ...
+class CumSumOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.RegionOp")
-class RegionOp(Node, Scriptable):
-    ...
+class RegionOp(Node, Scriptable): ...


@tvm_ffi.register_object("tl.ReduceType")
-class ReduceType(Node, Scriptable):
-    ...
+class ReduceType(Node, Scriptable): ...
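These stub classes only attach Python mirrors to C++ IR nodes via `tvm_ffi.register_object`; a toy sketch of the same decorator pattern (standalone registry, not the real FFI):

# Toy stand-in for tvm_ffi.register_object: maps a type key to a Python class.
_registry = {}

def register_object(type_key):
    def wrap(cls):
        _registry[type_key] = cls
        return cls
    return wrap

@register_object("tl.Fill")
class Fill: ...

print(_registry["tl.Fill"].__name__)  # -> Fill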