Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
479 additions
and
652 deletions
+479
-652
tilelang/contrib/hipcc.py
tilelang/contrib/hipcc.py
+2
-7
tilelang/contrib/nvcc.py
tilelang/contrib/nvcc.py
+12
-25
tilelang/contrib/nvrtc.py
tilelang/contrib/nvrtc.py
+10
-11
tilelang/contrib/rocm.py
tilelang/contrib/rocm.py
+5
-2
tilelang/engine/lower.py
tilelang/engine/lower.py
+17
-22
tilelang/engine/param.py
tilelang/engine/param.py
+3
-0
tilelang/engine/phase.py
tilelang/engine/phase.py
+6
-12
tilelang/env.py
tilelang/env.py
+33
-39
tilelang/intrinsics/mfma_layout.py
tilelang/intrinsics/mfma_layout.py
+9
-9
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+88
-107
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+11
-11
tilelang/intrinsics/mma_macro_generator.py
tilelang/intrinsics/mma_macro_generator.py
+50
-93
tilelang/intrinsics/mma_sm70_layout.py
tilelang/intrinsics/mma_sm70_layout.py
+3
-5
tilelang/intrinsics/mma_sm70_macro_generator.py
tilelang/intrinsics/mma_sm70_macro_generator.py
+21
-54
tilelang/intrinsics/mma_sp_layout.py
tilelang/intrinsics/mma_sp_layout.py
+9
-18
tilelang/intrinsics/mma_sp_macro_generator.py
tilelang/intrinsics/mma_sp_macro_generator.py
+41
-74
tilelang/intrinsics/tcgen05_macro_generator.py
tilelang/intrinsics/tcgen05_macro_generator.py
+49
-44
tilelang/intrinsics/utils.py
tilelang/intrinsics/utils.py
+1
-1
tilelang/intrinsics/wgmma_macro_generator.py
tilelang/intrinsics/wgmma_macro_generator.py
+93
-86
tilelang/ir.py
tilelang/ir.py
+16
-32
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/contrib/hipcc.py
View file @
29051439
...
...
@@ -16,12 +16,7 @@ from tvm.base import py_str
from
tvm.contrib.rocm
import
get_rocm_arch
,
find_rocm_path
def
compile_hip
(
code
,
target_format
=
"hsaco"
,
arch
=
None
,
options
=
None
,
path_target
=
None
,
verbose
=
False
):
def
compile_hip
(
code
,
target_format
=
"hsaco"
,
arch
=
None
,
options
=
None
,
path_target
=
None
,
verbose
=
False
):
"""Compile HIP code with hipcc.
Parameters
...
...
@@ -61,7 +56,7 @@ def compile_hip(code,
file_target
=
path_target
if
path_target
else
temp_target
cmd
=
[
"hipcc"
]
cmd
+=
[
"-O3"
,
'
-c
'
]
cmd
+=
[
"-O3"
,
"
-c
"
]
if
isinstance
(
arch
,
str
):
cmd
+=
[
f
"--offload-arch=
{
arch
}
"
]
if
target_format
==
"hsaco"
:
...
...
tilelang/contrib/nvcc.py
View file @
29051439
# pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from
__future__
import
annotations
import
os
...
...
@@ -18,12 +19,7 @@ from tvm.base import py_str
from
tvm.contrib
import
utils
def
compile_cuda
(
code
,
target_format
=
"ptx"
,
arch
=
None
,
options
=
None
,
path_target
=
None
,
verbose
=
False
):
def
compile_cuda
(
code
,
target_format
=
"ptx"
,
arch
=
None
,
options
=
None
,
path_target
=
None
,
verbose
=
False
):
"""Compile cuda code with NVCC from env.
Parameters
...
...
@@ -67,7 +63,7 @@ def compile_cuda(code,
temp_target
=
temp
.
relpath
(
f
"
{
file_name
}
.
{
target_format
}
"
)
pass_context
=
tvm
.
get_global_func
(
"transform.GetCurrentPassContext"
)()
kernels_output_dir
=
(
pass_context
.
config
.
get
(
"cuda.kernels_output_dir"
,
None
)
)
kernels_output_dir
=
pass_context
.
config
.
get
(
"cuda.kernels_output_dir"
,
None
)
if
kernels_output_dir
is
not
None
:
if
not
os
.
path
.
isdir
(
kernels_output_dir
):
os
.
makedirs
(
kernels_output_dir
)
...
...
@@ -114,10 +110,7 @@ def compile_cuda(code,
print
(
py_str
(
out
))
if
proc
.
returncode
!=
0
:
msg
=
f
"
{
code
}
\n
"
\
f
"Compilation error:
\n
"
\
f
"
{
py_str
(
out
)
}
\n
"
\
f
"Command:
{
' '
.
join
(
cmd
)
}
\n
"
msg
=
f
"
{
code
}
\n
Compilation error:
\n
{
py_str
(
out
)
}
\n
Command:
{
' '
.
join
(
cmd
)
}
\n
"
raise
RuntimeError
(
msg
)
with
open
(
file_target
,
"rb"
)
as
f
:
...
...
@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if
compile_flags
:
import
shlex
for
flag
in
compile_flags
:
# Split each string like a shell would, preserving quoted args
tokens
=
shlex
.
split
(
flag
)
if
isinstance
(
flag
,
str
)
else
[
str
(
flag
)]
...
...
@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
return
options
def
get_ptx_from_source
(
code
:
str
,
compile_flags
:
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
)
->
str
:
def
get_ptx_from_source
(
code
:
str
,
compile_flags
:
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
)
->
str
:
"""
Compile CUDA C++ source to PTX using NVCC and return as text.
...
...
@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None:
return
None
def
get_sass_from_source
(
code
:
str
,
compile_flags
:
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
)
->
str
:
def
get_sass_from_source
(
code
:
str
,
compile_flags
:
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
)
->
str
:
"""
Compile CUDA C++ source to CUBIN and disassemble to SASS.
...
...
@@ -246,9 +236,7 @@ def get_sass_from_source(code: str,
cand_nvdisasm
=
_find_tool
(
"nvdisasm"
)
cand_cuobjdump
=
_find_tool
(
"cuobjdump"
)
if
not
cand_nvdisasm
and
not
cand_cuobjdump
:
raise
RuntimeError
(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
raise
RuntimeError
(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
last_err
:
str
|
None
=
None
try
:
# Attempt nvdisasm first
...
...
@@ -268,8 +256,7 @@ def get_sass_from_source(code: str,
return
text
last_err
=
f
"
{
tool_name
}
rc=
{
proc
.
returncode
}
, output:
\n
{
text
}
"
# If we reach here, all attempts failed
raise
RuntimeError
(
f
"SASS disassembly failed. Tried tools: "
f
"
{
', '
.
join
(
name
for
name
,
_
in
tools_to_try
)
}
\n
{
last_err
or
''
}
"
)
raise
RuntimeError
(
f
"SASS disassembly failed. Tried tools:
{
', '
.
join
(
name
for
name
,
_
in
tools_to_try
)
}
\n
{
last_err
or
''
}
"
)
finally
:
with
contextlib
.
suppress
(
Exception
):
os
.
remove
(
cubin_path
)
...
...
@@ -438,8 +425,7 @@ def get_target_compute_version(target=None):
if
tvm
.
cuda
(
0
).
exist
:
return
tvm
.
cuda
(
0
).
compute_version
raise
ValueError
(
"No CUDA architecture was specified or GPU detected."
"Try specifying it by adding '-arch=sm_xx' to your target."
)
raise
ValueError
(
"No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target."
)
def
parse_compute_version
(
compute_version
)
->
tuple
[
int
,
int
]:
...
...
@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None):
warnings
.
warn
(
"Tensorcore will be disabled due to no CUDA architecture specified."
"Try specifying it by adding '-arch=sm_xx' to your target."
,
stacklevel
=
2
)
stacklevel
=
2
,
)
return
False
compute_version
=
target
.
attrs
[
"arch"
]
# Compute version will be in the form "sm_{major}{minor}"
...
...
tilelang/contrib/nvrtc.py
View file @
29051439
...
...
@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]:
return
(
major
,
minor
)
def
compile_cuda
(
code
:
str
,
target_format
:
Literal
[
"ptx"
,
"cubin"
]
=
"ptx"
,
arch
:
int
|
None
=
None
,
options
:
str
|
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
)
->
bytearray
:
def
compile_cuda
(
code
:
str
,
target_format
:
Literal
[
"ptx"
,
"cubin"
]
=
"ptx"
,
arch
:
int
|
None
=
None
,
options
:
str
|
list
[
str
]
|
None
=
None
,
verbose
:
bool
=
False
,
)
->
bytearray
:
"""Compile cuda code with NVRTC.
Parameters
...
...
@@ -43,8 +45,7 @@ def compile_cuda(code: str,
if
arch
is
None
:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc.
major
,
minor
=
parse_compute_version
(
get_target_compute_version
(
Target
.
current
(
allow_none
=
True
)))
major
,
minor
=
parse_compute_version
(
get_target_compute_version
(
Target
.
current
(
allow_none
=
True
)))
arch
=
major
*
10
+
minor
prefix
=
"compute"
if
target_format
==
"ptx"
else
"sm"
suffix
=
"a"
if
arch
>=
90
else
""
...
...
@@ -77,8 +78,7 @@ def compile_cuda(code: str,
compile_result
=
nvrtc
.
nvrtcCompileProgram
(
program
,
len
(
options_bytes
),
options_bytes
)[
0
]
if
compile_result
!=
nvrtc
.
nvrtcResult
.
NVRTC_SUCCESS
:
msg
=
f
"
{
code
}
\n
"
\
f
"Compilation error:
\n
"
msg
=
f
"
{
code
}
\n
Compilation error:
\n
"
if
verbose
:
result
,
log_size
=
nvrtc
.
nvrtcGetProgramLogSize
(
program
)
assert
result
==
nvrtc
.
nvrtcResult
.
NVRTC_SUCCESS
,
f
"Failed to get program log size:
{
result
}
"
...
...
@@ -105,7 +105,6 @@ def compile_cuda(code: str,
assert
result
==
nvrtc
.
nvrtcResult
.
NVRTC_SUCCESS
,
f
"Failed to get PTX:
{
result
}
"
# Destroy handler
assert
nvrtc
.
nvrtcDestroyProgram
(
program
)[
0
]
==
nvrtc
.
nvrtcResult
.
NVRTC_SUCCESS
,
f
"Failed to destroy program:
{
result
}
"
assert
nvrtc
.
nvrtcDestroyProgram
(
program
)[
0
]
==
nvrtc
.
nvrtcResult
.
NVRTC_SUCCESS
,
f
"Failed to destroy program:
{
result
}
"
return
result_bytes
tilelang/contrib/rocm.py
View file @
29051439
...
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Utility for ROCm backend"""
# ruff: noqa
import
re
import
subprocess
...
...
@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"):
gpu_arch
=
match
.
group
(
1
)
return
gpu_arch
except
subprocess
.
CalledProcessError
:
print
(
f
"Unable to execute rocminfo command,
\
print
(
f
"Unable to execute rocminfo command,
\
please ensure ROCm is installed and you have an AMD GPU on your system.
\
using default
{
gpu_arch
}
."
)
using default
{
gpu_arch
}
."
)
return
gpu_arch
...
...
tilelang/engine/lower.py
View file @
29051439
"""The compiler for TL programs."""
from
__future__
import
annotations
import
os
...
...
@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target):
def
has_device_kernel_launch
(
attrs
)
->
bool
:
"""Check if the attributes indicate a device kernel launch."""
return
bool
(
attrs
and
"calling_conv"
in
attrs
and
attrs
[
"calling_conv"
]
==
CallingConv
.
DEVICE_KERNEL_LAUNCH
)
return
bool
(
attrs
and
"calling_conv"
in
attrs
and
attrs
[
"calling_conv"
]
==
CallingConv
.
DEVICE_KERNEL_LAUNCH
)
def
is_device_call_c_device
(
func
:
tir
.
PrimFunc
):
attrs
=
func
.
attrs
calling_conv
=
attrs
.
get
(
"calling_conv"
,
CallingConv
.
DEFAULT
)
is_cpacked
=
(
calling_conv
==
CallingConv
.
C_PACKED_FUNC
)
is_cpacked
=
calling_conv
==
CallingConv
.
C_PACKED_FUNC
# Check if it's a C target
if
"target"
in
attrs
and
attrs
[
"target"
].
kind
.
name
==
"c"
and
not
is_cpacked
:
...
...
@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
if
var
in
func
.
buffer_map
:
tensor_types
.
append
(
KernelParam
.
from_buffer
(
func
.
buffer_map
[
var
]))
else
:
if
var
.
dtype
==
'
handle
'
:
if
var
.
dtype
==
"
handle
"
:
raise
ValueError
(
f
'Handle parameter
{
var
}
must be mapped to a buffer.
\n
'
f
'Please use T.tensor(
{
var
.
name
}
, shape=..., dtype=...) to map it to a buffer.'
)
f
"Handle parameter
{
var
}
must be mapped to a buffer.
\n
"
f
"Please use T.tensor(
{
var
.
name
}
, shape=..., dtype=...) to map it to a buffer."
)
tensor_types
.
append
(
KernelParam
.
from_var
(
var
))
return
tensor_types
def
canon_target_host
(
target
:
str
|
Target
,
target_host
:
str
|
Target
|
None
):
if
not
target_host
:
target_host
=
"llvm"
if
tvm
.
runtime
.
enabled
(
"llvm"
)
else
"c"
...
...
@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod
=
tilelang
.
transform
.
LowerIntrin
()(
device_mod
)
device_mod
=
tir
.
transform
.
Simplify
()(
device_mod
)
if
target
.
kind
.
name
==
"cuda"
:
device_mod
=
tvm
.
ffi
.
get_global_func
(
"target.build.tilelang_cuda_without_compile"
)(
device_mod
,
target
)
device_mod
=
tvm
.
ffi
.
get_global_func
(
"target.build.tilelang_cuda_without_compile"
)(
device_mod
,
target
)
elif
target
.
kind
.
name
==
"hip"
:
device_mod
=
tvm
.
ffi
.
get_global_func
(
"target.build.tilelang_hip_without_compile"
)(
device_mod
,
target
)
device_mod
=
tvm
.
ffi
.
get_global_func
(
"target.build.tilelang_hip_without_compile"
)(
device_mod
,
target
)
elif
target
.
kind
.
name
==
"c"
:
device_mod
=
tvm
.
ffi
.
get_global_func
(
"target.build.tilelang_cpp"
)(
device_mod
,
target
)
elif
target
.
kind
.
name
==
"llvm"
:
...
...
@@ -222,12 +220,12 @@ def lower(
enable_host_codegen
=
False
,
enable_device_compile
=
False
,
)
->
CompiledArtifact
:
'''
enable_host_codegen: whether to enable host codegen, default is False, as we have our
own host codegen implementation in jit.
enable_device_compile: whether to enable device codegen, default is False, as we have our
own device codegen implementation in jit.
'''
"""
enable_host_codegen: whether to enable host codegen, default is False, as we have our
own host codegen implementation in jit.
enable_device_compile: whether to enable device codegen, default is False, as we have our
own device codegen implementation in jit.
"""
mod
=
func_or_mod
params
=
None
...
...
@@ -259,14 +257,11 @@ def lower(
host_mod
=
tir
.
transform
.
Filter
(
_is_host_call
)(
mod
)
device_mod
=
tir
.
transform
.
Filter
(
_is_device_call
)(
mod
)
codegen_mod
=
device_codegen
(
device_mod
,
target
)
if
enable_device_compile
else
device_codegen_without_compile
(
device_mod
,
target
)
codegen_mod
=
device_codegen
(
device_mod
,
target
)
if
enable_device_compile
else
device_codegen_without_compile
(
device_mod
,
target
)
if
enable_host_codegen
:
host_mod
=
host_codegen
(
host_mod
,
target_host
)
host_mod
.
import_module
(
codegen_mod
)
return
CompiledArtifact
(
host_mod
,
device_mod
,
params
,
codegen_mod
.
inspect_source
(),
rt_mod
=
host_mod
)
return
CompiledArtifact
(
host_mod
,
device_mod
,
params
,
codegen_mod
.
inspect_source
(),
rt_mod
=
host_mod
)
return
CompiledArtifact
(
host_mod
,
device_mod
,
params
,
codegen_mod
.
inspect_source
())
tilelang/engine/param.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
__future__
import
annotations
from
dataclasses
import
dataclass
...
...
@@ -14,6 +15,7 @@ class KernelParam:
Represents parameters for a kernel operation, storing dtype and shape information.
Used to describe tensor or scalar parameters in TVM/PyTorch interop.
"""
dtype
:
torch
.
dtype
# PyTorch data type of the parameter
shape
:
list
[
int
|
Var
]
# List of dimensions, can be integers or TVM variables
...
...
@@ -109,6 +111,7 @@ class CompiledArtifact:
Represents a compiled kernel artifact containing both host and device code.
Stores all necessary components for kernel execution in the TVM runtime.
"""
host_mod
:
tvm
.
IRModule
# Host-side TVM IR module for managing kernel execution
device_mod
:
tvm
.
IRModule
# Device-side TVM IR module containing the actual kernel code
params
:
list
[
KernelParam
]
# List of parameters (tensors/scalars) used by the kernel
...
...
tilelang/engine/phase.py
View file @
29051439
...
...
@@ -6,8 +6,7 @@ from tilelang.transform import PassContext
from
tilelang.contrib.nvcc
import
have_tma
,
is_hopper
def
allow_warp_specialized
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
def
allow_warp_specialized
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
# avoid circular import
from
tilelang.jit.adapter.utils
import
is_cuda_target
...
...
@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None,
return
not
disable_warp_specialized
def
allow_tma_and_warp_specialized
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
def
allow_tma_and_warp_specialized
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
if
pass_ctx
is
None
:
pass_ctx
=
tilelang
.
transform
.
get_pass_context
()
if
not
have_tma
(
target
):
...
...
@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) ->
return
enable_global_thread_sync
def
should_enable_aggressive_merge
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
def
should_enable_aggressive_merge
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
if
pass_ctx
is
None
:
pass_ctx
=
tilelang
.
transform
.
get_pass_context
()
enable_aggressive_merge
=
bool
(
pass_ctx
.
config
.
get
(
tilelang
.
PassConfigKey
.
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE
,
False
))
enable_aggressive_merge
=
bool
(
pass_ctx
.
config
.
get
(
tilelang
.
PassConfigKey
.
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE
,
False
))
if
allow_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
# when warp specialization is enabled, as different warp threads may access different
...
...
@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
return
[
"txt"
,
"png"
,
"pdf"
,
"svg"
]
if
","
in
formats_str
:
formats_list
=
[
f
.
strip
()
for
f
in
formats_str
.
split
(
','
)]
formats_list
=
[
f
.
strip
()
for
f
in
formats_str
.
split
(
","
)]
else
:
formats_list
=
[
formats_str
]
...
...
@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge
=
should_enable_aggressive_merge
(
pass_ctx
=
pass_ctx
,
target
=
target
)
mod
=
tilelang
.
transform
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
=
enable_aggressive_merge
)(
mod
)
mod
=
tilelang
.
transform
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
=
enable_aggressive_merge
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
# Inject PTX async copy must behind the thread sync pass
...
...
tilelang/env.py
View file @
29051439
...
...
@@ -10,36 +10,34 @@ from dataclasses import dataclass
logger
=
logging
.
getLogger
(
__name__
)
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE
=
(
"CUTLASS is not installed or found in the expected path"
)
CUTLASS_NOT_FOUND_MESSAGE
=
"CUTLASS is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE
=
(
"Composable Kernel is not installed or found in the expected path"
)
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE
=
"Composable Kernel is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE
=
(
"TileLang is not installed or found in the expected path"
)
TL_TEMPLATE_NOT_FOUND_MESSAGE
=
"TileLang is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE
=
(
"TVM is not installed or found in the expected path"
)
TVM_LIBRARY_NOT_FOUND_MESSAGE
=
"TVM is not installed or found in the expected path"
TL_ROOT
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
# Only expose the internal lib directory to sys.path to avoid shadowing
# common top-level module names (e.g., utils, analysis) from user projects.
TL_LIBS
=
[
os
.
path
.
join
(
TL_ROOT
,
'
lib
'
)]
TL_LIBS
=
[
os
.
path
.
join
(
TL_ROOT
,
"
lib
"
)]
TL_LIBS
=
[
i
for
i
in
TL_LIBS
if
os
.
path
.
exists
(
i
)]
DEV
=
False
THIRD_PARTY_ROOT
=
os
.
path
.
join
(
TL_ROOT
,
'
3rdparty
'
)
THIRD_PARTY_ROOT
=
os
.
path
.
join
(
TL_ROOT
,
"
3rdparty
"
)
if
not
os
.
path
.
exists
(
THIRD_PARTY_ROOT
):
DEV
=
True
tl_dev_root
=
os
.
path
.
dirname
(
TL_ROOT
)
dev_lib_root
=
os
.
path
.
join
(
tl_dev_root
,
'
build
'
)
dev_lib_root
=
os
.
path
.
join
(
tl_dev_root
,
"
build
"
)
# In dev builds, place artifacts under build/lib and point search path there
# to avoid adding the entire build root to sys.path.
TL_LIBS
=
[
os
.
path
.
join
(
dev_lib_root
,
'
lib
'
),
os
.
path
.
join
(
dev_lib_root
,
'
tvm
'
)]
THIRD_PARTY_ROOT
=
os
.
path
.
join
(
tl_dev_root
,
'
3rdparty
'
)
logger
.
warning
(
f
'
Loading tilelang libs from dev root:
{
dev_lib_root
}
'
)
TL_LIBS
=
[
os
.
path
.
join
(
dev_lib_root
,
"
lib
"
),
os
.
path
.
join
(
dev_lib_root
,
"
tvm
"
)]
THIRD_PARTY_ROOT
=
os
.
path
.
join
(
tl_dev_root
,
"
3rdparty
"
)
logger
.
warning
(
f
"
Loading tilelang libs from dev root:
{
dev_lib_root
}
"
)
assert
TL_LIBS
and
all
(
os
.
path
.
exists
(
i
)
for
i
in
TL_LIBS
),
f
'tilelang lib root do not exists:
{
TL_LIBS
}
'
assert
TL_LIBS
and
all
(
os
.
path
.
exists
(
i
)
for
i
in
TL_LIBS
),
f
"tilelang lib root do not exists:
{
TL_LIBS
}
"
for
lib
in
TL_LIBS
:
if
lib
not
in
sys
.
path
:
...
...
@@ -52,7 +50,7 @@ def _find_cuda_home() -> str:
Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
"""
# Guess #1
cuda_home
=
os
.
environ
.
get
(
'
CUDA_HOME
'
)
or
os
.
environ
.
get
(
'
CUDA_PATH
'
)
cuda_home
=
os
.
environ
.
get
(
"
CUDA_HOME
"
)
or
os
.
environ
.
get
(
"
CUDA_PATH
"
)
if
cuda_home
is
None
:
# Guess #2
nvcc_path
=
shutil
.
which
(
"nvcc"
)
...
...
@@ -70,15 +68,15 @@ def _find_cuda_home() -> str:
else
:
# Guess #3
if
sys
.
platform
==
'
win32
'
:
cuda_homes
=
glob
.
glob
(
'
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
'
)
cuda_home
=
''
if
len
(
cuda_homes
)
==
0
else
cuda_homes
[
0
]
if
sys
.
platform
==
"
win32
"
:
cuda_homes
=
glob
.
glob
(
"
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
"
)
cuda_home
=
""
if
len
(
cuda_homes
)
==
0
else
cuda_homes
[
0
]
else
:
# Linux/macOS
if
os
.
path
.
exists
(
'
/usr/local/cuda
'
):
cuda_home
=
'
/usr/local/cuda
'
elif
os
.
path
.
exists
(
'
/opt/nvidia/hpc_sdk/Linux_x86_64
'
):
cuda_home
=
'
/opt/nvidia/hpc_sdk/Linux_x86_64
'
if
os
.
path
.
exists
(
"
/usr/local/cuda
"
):
cuda_home
=
"
/usr/local/cuda
"
elif
os
.
path
.
exists
(
"
/opt/nvidia/hpc_sdk/Linux_x86_64
"
):
cuda_home
=
"
/opt/nvidia/hpc_sdk/Linux_x86_64
"
# Validate found path
if
cuda_home
is
None
or
not
os
.
path
.
exists
(
cuda_home
):
...
...
@@ -89,13 +87,13 @@ def _find_cuda_home() -> str:
def
_find_rocm_home
()
->
str
:
"""Find the ROCM install path."""
rocm_home
=
os
.
environ
.
get
(
'
ROCM_PATH
'
)
or
os
.
environ
.
get
(
'
ROCM_HOME
'
)
rocm_home
=
os
.
environ
.
get
(
"
ROCM_PATH
"
)
or
os
.
environ
.
get
(
"
ROCM_HOME
"
)
if
rocm_home
is
None
:
rocmcc_path
=
shutil
.
which
(
"hipcc"
)
if
rocmcc_path
is
not
None
:
rocm_home
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
rocmcc_path
))
else
:
rocm_home
=
'
/opt/rocm
'
rocm_home
=
"
/opt/rocm
"
if
not
os
.
path
.
exists
(
rocm_home
):
rocm_home
=
None
return
rocm_home
if
rocm_home
is
not
None
else
""
...
...
@@ -104,6 +102,7 @@ def _find_rocm_home() -> str:
# Cache control
class
CacheState
:
"""Class to manage global kernel caching state."""
_enabled
=
True
@
classmethod
...
...
@@ -230,13 +229,11 @@ class Environment:
TILELANG_TMP_DIR
=
EnvVar
(
"TILELANG_TMP_DIR"
,
os
.
path
.
join
(
TILELANG_CACHE_DIR
.
get
(),
"tmp"
))
# Kernel Build options
TILELANG_PRINT_ON_COMPILATION
=
EnvVar
(
"TILELANG_PRINT_ON_COMPILATION"
,
"1"
)
# print kernel name on compile
TILELANG_PRINT_ON_COMPILATION
=
EnvVar
(
"TILELANG_PRINT_ON_COMPILATION"
,
"1"
)
# print kernel name on compile
TILELANG_DISABLE_CACHE
=
EnvVar
(
"TILELANG_DISABLE_CACHE"
,
"0"
)
# disable kernel cache, usually for unit testing / debugging, high priority
TILELANG_CLEAR_CACHE
=
EnvVar
(
"TILELANG_CLEAR_CACHE"
,
"0"
)
# DEPRECATED! clear cache automatically if set
"TILELANG_DISABLE_CACHE"
,
"0"
)
# disable kernel cache, usually for unit testing / debugging, high priority
TILELANG_CLEAR_CACHE
=
EnvVar
(
"TILELANG_CLEAR_CACHE"
,
"0"
)
# DEPRECATED! clear cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
...
...
@@ -244,12 +241,9 @@ class Environment:
# Auto-tuning settings
TILELANG_AUTO_TUNING_DISABLE_CACHE
=
EnvVar
(
"TILELANG_AUTO_TUNING_DISABLE_CACHE"
,
"0"
)
TILELANG_AUTO_TUNING_CPU_UTILITIES
=
EnvVar
(
"TILELANG_AUTO_TUNING_CPU_UTILITIES"
,
"0.9"
)
# percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS
=
EnvVar
(
"TILELANG_AUTO_TUNING_CPU_COUNTS"
,
"-1"
)
# -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT
=
EnvVar
(
"TILELANG_AUTO_TUNING_MAX_CPU_COUNT"
,
"-1"
)
# -1 means no limit
TILELANG_AUTO_TUNING_CPU_UTILITIES
=
EnvVar
(
"TILELANG_AUTO_TUNING_CPU_UTILITIES"
,
"0.9"
)
# percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS
=
EnvVar
(
"TILELANG_AUTO_TUNING_CPU_COUNTS"
,
"-1"
)
# -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT
=
EnvVar
(
"TILELANG_AUTO_TUNING_MAX_CPU_COUNT"
,
"-1"
)
# -1 means no limit
# TVM integration
SKIP_LOADING_TILELANG_SO
=
EnvVar
(
"SKIP_LOADING_TILELANG_SO"
,
"0"
)
...
...
@@ -323,18 +317,18 @@ def prepend_pythonpath(path):
if
env
.
TVM_IMPORT_PYTHON_PATH
is
not
None
:
prepend_pythonpath
(
env
.
TVM_IMPORT_PYTHON_PATH
)
else
:
tvm_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
'
tvm
'
,
'
python
'
)
tvm_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
"
tvm
"
,
"
python
"
)
assert
os
.
path
.
exists
(
tvm_path
),
tvm_path
if
tvm_path
not
in
sys
.
path
:
prepend_pythonpath
(
tvm_path
)
env
.
TVM_IMPORT_PYTHON_PATH
=
tvm_path
# By default, the built TVM-related libraries are stored in TL_LIBS.
if
os
.
environ
.
get
(
"TVM_LIBRARY_PATH"
)
is
None
:
os
.
environ
[
'
TVM_LIBRARY_PATH
'
]
=
env
.
TVM_LIBRARY_PATH
=
os
.
pathsep
.
join
(
TL_LIBS
)
os
.
environ
[
"
TVM_LIBRARY_PATH
"
]
=
env
.
TVM_LIBRARY_PATH
=
os
.
pathsep
.
join
(
TL_LIBS
)
# Initialize CUTLASS paths
if
os
.
environ
.
get
(
"TL_CUTLASS_PATH"
,
None
)
is
None
:
cutlass_inc_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
'
cutlass
'
,
'
include
'
)
cutlass_inc_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
"
cutlass
"
,
"
include
"
)
if
os
.
path
.
exists
(
cutlass_inc_path
):
os
.
environ
[
"TL_CUTLASS_PATH"
]
=
env
.
CUTLASS_INCLUDE_DIR
=
cutlass_inc_path
else
:
...
...
@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
# Initialize COMPOSABLE_KERNEL paths
if
os
.
environ
.
get
(
"TL_COMPOSABLE_KERNEL_PATH"
,
None
)
is
None
:
ck_inc_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
'
composable_kernel
'
,
'
include
'
)
ck_inc_path
=
os
.
path
.
join
(
THIRD_PARTY_ROOT
,
"
composable_kernel
"
,
"
include
"
)
if
os
.
path
.
exists
(
ck_inc_path
):
os
.
environ
[
"TL_COMPOSABLE_KERNEL_PATH"
]
=
env
.
COMPOSABLE_KERNEL_INCLUDE_DIR
=
ck_inc_path
else
:
...
...
tilelang/intrinsics/mfma_layout.py
View file @
29051439
...
...
@@ -4,7 +4,7 @@ import tilelang.language as T
def
shared_16x4_to_local_64x1_layout_A
(
i
,
j
):
thread_id
=
(
j
*
16
+
i
)
thread_id
=
j
*
16
+
i
return
thread_id
,
convert
(
0
)
...
...
@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
def
shared_4x16_to_local_64x1_layout_B
(
i
,
j
):
thread_id
=
(
i
*
16
+
j
)
thread_id
=
i
*
16
+
j
return
thread_id
,
convert
(
0
)
...
...
@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
def
shared_16x16_to_local_64x4_layout_C
(
i
,
j
):
thread_id
=
j
+
(
i
//
4
)
*
16
local
=
(
i
%
4
)
local
=
i
%
4
return
thread_id
,
local
...
...
@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
def
shared_16x16_to_local_64x4_layout_A
(
i
,
j
):
thread_id
=
i
+
16
*
(
j
//
4
)
local
=
(
j
%
4
)
local
=
j
%
4
return
thread_id
,
local
...
...
@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
def
shared_16x16_to_local_64x4_layout_B
(
i
,
j
):
thread_id
=
j
+
(
i
//
4
)
*
16
local
=
(
i
%
4
)
local
=
i
%
4
return
thread_id
,
local
...
...
@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
def
shared_16x32_to_local_64x8_layout_A
(
i
,
j
):
thread_id
=
i
+
16
*
(
j
//
8
)
local
=
(
j
%
8
)
local
=
j
%
8
return
thread_id
,
local
...
...
@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
def
shared_16x32_to_local_64x8_layout_B
(
i
,
j
):
thread_id
=
j
+
(
i
//
8
)
*
16
local
=
(
i
%
8
)
local
=
i
%
8
return
thread_id
,
local
...
...
@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
def
shared_16x64_to_local_64x16_layout_A
(
i
,
j
):
thread_id
=
i
+
16
*
(
j
//
16
)
local
=
(
j
%
16
)
local
=
j
%
16
return
thread_id
,
local
...
...
@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
def
shared_16x64_to_local_64x16_layout_B
(
i
,
j
):
thread_id
=
i
+
16
*
(
j
//
16
)
local
=
(
j
%
16
)
local
=
j
%
16
return
thread_id
,
local
...
...
tilelang/intrinsics/mfma_macro_generator.py
View file @
29051439
...
...
@@ -6,7 +6,7 @@ from tvm import tir
from
tvm.ir
import
Range
from
tvm.tir
import
PrimExpr
,
IndexMap
,
Buffer
,
Var
,
BufferRegion
,
BufferLoad
from
tvm.runtime
import
convert
from
.utils
import
(
mfma_store_index_map
)
from
.utils
import
mfma_store_index_map
from
typing
import
Literal
,
Callable
from
tilelang.utils
import
is_fragment
...
...
@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
reduce_k
=
reduce_k
self
.
threads
=
(
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
)
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
self
.
thread_var
=
thread_var
...
...
@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
def
_initialize_mfma_prefix
(
self
,
k_dim
=
16
):
in_dtype
,
out_dtype
=
self
.
a_dtype
,
self
.
accum_dtype
M_DIM
,
N_DIM
=
self
.
M_DIM
,
self
.
N_DIM
out_dtype_abbrv
=
{
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"int32"
:
"i32"
}[
out_dtype
]
out_dtype_abbrv
=
{
"float16"
:
"f16"
,
"float32"
:
"f32"
,
"int8"
:
"i8"
,
"int32"
:
"i32"
}[
out_dtype
]
in_dtype_abbrv
=
{
"bfloat16"
:
"bf16"
,
...
...
@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
self
.
b_preshuffle
=
b_preshuffle
def
get_ldmatrix_index_map
(
self
,
is_b
=
False
):
k_dim
=
self
.
k_dim
*
self
.
k_pack
transposed
=
self
.
a_transposed
if
not
is_b
else
self
.
b_transposed
if
k_dim
==
4
:
...
...
@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
reverse_index_map
=
thread_id_shared_access_64x1_to_16x4_layout_A
if
is_b
:
index_map
=
shared_16x4_to_local_64x1_layout_A
if
transposed
else
shared_4x16_to_local_64x1_layout_B
reverse_index_map
=
thread_id_shared_access_64x1_to_16x4_layout_A
if
transposed
else
thread_id_shared_access_64x1_to_4x16_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x1_to_16x4_layout_A
if
transposed
else
thread_id_shared_access_64x1_to_4x16_layout_B
)
elif
k_dim
==
16
:
index_map
=
shared_16x16_to_local_64x4_layout_B
if
transposed
else
shared_16x16_to_local_64x4_layout_A
reverse_index_map
=
thread_id_shared_access_64x4_to_16x16_layout_B
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x4_to_16x16_layout_B
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_A
)
if
is_b
:
index_map
=
shared_16x16_to_local_64x4_layout_A
if
transposed
else
shared_16x16_to_local_64x4_layout_B
reverse_index_map
=
thread_id_shared_access_64x4_to_16x16_layout_A
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x4_to_16x16_layout_A
if
transposed
else
thread_id_shared_access_64x4_to_16x16_layout_B
)
elif
k_dim
==
32
:
index_map
=
shared_16x32_to_local_64x8_layout_B
if
transposed
else
shared_16x32_to_local_64x8_layout_A
reverse_index_map
=
thread_id_shared_access_64x8_to_16x32_layout_B
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x8_to_16x32_layout_B
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_A
)
if
is_b
:
index_map
=
shared_16x32_to_local_64x8_layout_A
if
transposed
else
shared_16x32_to_local_64x8_layout_B
reverse_index_map
=
thread_id_shared_access_64x8_to_16x32_layout_A
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x8_to_16x32_layout_A
if
transposed
else
thread_id_shared_access_64x8_to_16x32_layout_B
)
elif
k_dim
==
64
:
index_map
=
shared_16x64_to_local_64x16_layout_B
if
transposed
else
shared_16x64_to_local_64x16_layout_A
reverse_index_map
=
thread_id_shared_access_64x16_to_16x64_layout_B
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_A
reverse_index_map
=
(
thread_id_shared_access_64x16_to_16x64_layout_B
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_A
)
if
is_b
:
index_map
=
shared_16x64_to_local_64x16_layout_A
if
transposed
else
shared_16x64_to_local_64x16_layout_B
reverse_index_map
=
thread_id_shared_access_64x16_to_16x64_layout_A
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_B
reverse_index_map
=
(
thread_id_shared_access_64x16_to_16x64_layout_A
if
transposed
else
thread_id_shared_access_64x16_to_16x64_layout_B
)
else
:
raise
ValueError
(
"k_dim must be 4 or 16 or 32 or 64 currently"
)
...
...
@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
else
:
return
self
.
thread_var
def
extract_thread_binding
(
self
,
thread_id
,
is_m_first
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
'''
def
extract_thread_binding
(
self
,
thread_id
,
is_m_first
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
"""
WARP_SIZE
=
self
.
WARP_SIZE
block_row_warps
=
self
.
block_row_warps
block_col_warps
=
self
.
block_col_warps
...
...
@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
is_m_first
=
self
.
is_m_first
if
is_m_first
:
lane_id
,
warp_n
,
warp_m
=
thread_id
%
WARP_SIZE
,
(
thread_id
//
WARP_SIZE
)
%
block_col_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_col_warps
))
%
block_row_warps
,
lane_id
,
warp_n
,
warp_m
=
(
thread_id
%
WARP_SIZE
,
(
thread_id
//
WARP_SIZE
)
%
block_col_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_col_warps
))
%
block_row_warps
,
)
return
lane_id
,
warp_n
,
warp_m
else
:
lane_id
,
warp_m
,
warp_n
=
thread_id
%
WARP_SIZE
,
(
thread_id
//
WARP_SIZE
)
%
block_row_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_row_warps
))
%
block_col_warps
,
lane_id
,
warp_m
,
warp_n
=
(
thread_id
%
WARP_SIZE
,
(
thread_id
//
WARP_SIZE
)
%
block_row_warps
,
(
thread_id
//
(
WARP_SIZE
*
block_row_warps
))
%
block_col_warps
,
)
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
...
...
@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
...
...
@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
def
mfma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
def
mfma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
for
local_id
in
T
.
vectorized
(
local_size_out
):
row
,
col
=
T
.
meta_var
(
mfma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
else
:
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
...
@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter:
for
i
,
j
in
T
.
grid
(
warp_rows
,
warp_cols
):
for
local_id
in
T
.
vectorized
(
local_size_out
):
row
,
col
=
T
.
meta_var
(
mfma_store_index_map
(
tx
,
local_id
))
C_buf
[(
pid_m
*
BLOCK_M
+
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
pid_n
*
BLOCK_N
+
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
return
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
def
make_mfma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
C_buf
[
(
pid_m
*
BLOCK_M
+
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
pid_n
*
BLOCK_N
+
warp_n
*
warp_cols
+
j
)
*
N_DIM
+
col
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
)
def
make_mfma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MFMA results into a fragment buffer.
...
...
@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
matrix_is_a
:
bool
=
matrix
==
"A"
matrix_is_b
:
bool
=
matrix
==
"B"
...
...
@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter:
transform_func
:
Callable
=
None
if
matrix_is_a
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
elif
matrix_is_b
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter:
return
local_id
base_fragment
=
T
.
Fragment
(
[
micro_size_s
,
micro_size_r
*
self
.
k_pack
]
if
is_sr_axis_order
else
[
micro_size_r
*
self
.
k_pack
,
micro_size_s
],
[
micro_size_s
,
micro_size_r
*
self
.
k_pack
]
if
is_sr_axis_order
else
[
micro_size_r
*
self
.
k_pack
,
micro_size_s
],
forward_thread_fn
=
forward_thread
,
forward_index_fn
=
forward_index
,
)
...
...
@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter:
replicate
=
block_col_warps
if
matrix_is_a
else
block_row_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
else
:
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
...
...
@@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter:
class
MatrixCorePreshuffleIntrinEmitter
(
MatrixCoreIntrinEmitter
):
def
__init__
(
self
,
a_dtype
:
str
=
"float16"
,
...
...
@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_m
*
warp_rows
+
i
,
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
else
:
print
(
self
.
a_preshuffle
)
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_rows
+
i
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
return
_warp_ldmatrix_a_global
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_a_shared
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
return
(
_warp_ldmatrix_a_global
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_a_shared
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
)
def
ldmatrix_b
(
self
,
B_local_buf
,
B_buf
,
ki
,
rk
=
0
,
pid_m
=
None
,
pid_n
=
None
):
warp_cols
=
self
.
warp_cols
...
...
@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
...
...
@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_cols
+
j
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
return
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
return
(
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
)
tilelang/intrinsics/mma_layout.py
View file @
29051439
...
...
@@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
def
mma_load_a_32x8_to_shared_16x16_layout
(
thread_id
,
local_id
):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
groupID + 8 Otherwise
row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
groupID + 8 Otherwise
col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
"""
row
=
(
thread_id
//
4
)
+
8
*
(
local_id
%
4
//
2
)
col
=
(
thread_id
%
4
)
*
2
+
(
local_id
%
2
)
+
8
*
(
local_id
//
4
)
...
...
@@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
def
mma_load_b_32x8_to_shared_16x16_layout
(
thread_id
,
local_id
):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
(threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
(threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
col = groupID
col = groupID
"""
col
=
(
thread_id
%
4
)
*
2
+
((
local_id
%
4
)
%
2
)
+
((
local_id
%
4
)
//
2
)
*
8
row
=
(
thread_id
//
4
)
+
8
*
(
local_id
//
4
)
...
...
tilelang/intrinsics/mma_macro_generator.py
View file @
29051439
...
...
@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter:
def
get_store_index_map
(
self
,
inverse
:
bool
=
False
)
->
IndexMap
:
from
.utils
import
mma_store_index_map
,
mma_store_index_map_fp64
warp_size
,
local_size_c
=
self
.
WARP_SIZE
,
self
.
local_size_out
if
DataType
(
self
.
accum_dtype
).
bits
==
64
:
index_map
=
IndexMap
.
from_func
(
mma_store_index_map_fp64
,
index_dtype
=
"int32"
)
...
...
@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter:
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
return
inverse_index_map
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter:
)
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if
DataType
(
self
.
a_dtype
).
bits
==
64
:
warp_row_tiles
=
self
.
warp_row_tiles
...
...
@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter:
for
i
in
T
.
serial
(
warp_rows
):
# Assign A_shared_buf_elem
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
micro_size_k
A_shared_buf_elem
=
A_buf
[
A_base0
+
wk
,
A_base1
+
wi
]
if
a_transposed
else
A_buf
[
A_base0
+
wi
,
A_base1
+
wk
]
A_shared_buf_elem
=
A_buf
[
A_base0
+
wk
,
A_base1
+
wi
]
if
a_transposed
else
A_buf
[
A_base0
+
wi
,
A_base1
+
wk
]
if
ldmatrix_available
:
T
.
ptx_ldmatrix
(
...
...
@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter:
for
j
in
T
.
serial
(
local_size_a
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
if
a_transposed
:
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wk
+
mk
,
A_base1
+
wi
+
mi
]
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wk
+
mk
,
A_base1
+
wi
+
mi
]
else
:
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wi
+
mi
,
A_base1
+
wk
+
mk
]
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_buf
[
A_base0
+
wi
+
mi
,
A_base1
+
wk
+
mk
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_region
,
ki
,
thread_binding
,
rk
)
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if
DataType
(
self
.
b_dtype
).
bits
==
64
:
warp_col_tiles
=
self
.
warp_col_tiles
...
...
@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter:
B_base0
=
B_region
.
region
[
-
2
].
min
B_base1
=
B_region
.
region
[
-
1
].
min
B_stride_last
=
B_buf
.
shape
[
-
1
]
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available
=
not
(
DataType
(
b_dtype
).
bits
!=
16
and
not
b_transposed
)
...
...
@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter:
)
if
ldmatrix_available
:
B_shared_buf_elem
=
B_buf
[
B_base0
+
wi
,
B_base1
+
wk
]
if
b_transposed
else
B_buf
[
B_base0
+
wk
,
B_base1
+
wi
]
B_shared_buf_elem
=
B_buf
[
B_base0
+
wi
,
B_base1
+
wk
]
if
b_transposed
else
B_buf
[
B_base0
+
wk
,
B_base1
+
wi
]
T
.
ptx_ldmatrix
(
b_dtype
,
...
...
@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter:
for
j
in
T
.
serial
(
local_size_b
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
if
b_transposed
:
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
else
:
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter:
accum_dtype
=
self
.
accum_dtype
accum_dtype_abbrv
=
self
.
accum_dtype_abbrv
mma_prefix
=
self
.
mma_prefix
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
a_is_fragment
=
is_fragment
(
A_local_buf
)
b_is_fragment
=
is_fragment
(
B_local_buf
)
...
...
@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter:
B_local_buf
.
data
,
b_local_stride
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
T
.
bool
(
False
),
# saturate
)
...
...
@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter:
local_id
=
local_id_o
*
2
+
local_id_i
row
,
col
=
T
.
meta_var
(
mma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
else
:
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
...
@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter:
C_buf
[
(
pid_m
*
BLOCK_M
+
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
pid_n
*
BLOCK_N
+
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
,
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
))
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
)
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
matrix_is_a
:
bool
=
matrix
==
"A"
matrix_is_b
:
bool
=
matrix
==
"B"
...
...
@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix_is_a
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
elif
matrix_is_b
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter:
replicate
=
block_col_warps
if
matrix_is_a
else
block_row_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
else
:
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
...
...
@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter:
from
tilelang.utils
import
is_fragment
shape
=
local_buf
.
shape
assert
is_fragment
(
local_buf
),
f
"local_buf
{
local_buf
}
must be a fragment, but got
{
local_buf
.
scope
()
}
"
assert
is_fragment
(
local_buf
),
f
"local_buf
{
local_buf
}
must be a fragment, but got
{
local_buf
.
scope
()
}
"
inverse_mma_store_layout
=
self
.
get_store_index_map
(
inverse
=
True
)
micro_size_x
,
micro_size_y
=
self
.
micro_size_x
,
self
.
micro_size_y
...
...
@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
".b16"
,
A_local_buf
.
data
,
i
*
local_size_a
,
T
.
address_of
(
A_shared_buf
[
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
micro_size_k
,
]),
T
.
address_of
(
A_shared_buf
[
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
micro_size_k
,
]
),
get_ldmatrix_offset
(
"A"
,
tx
,
0
,
stride
,
a_dtype
,
a_transposed
),
)
elif
transform_kind_a
==
TransformKind
.
InterWarpTransform
:
...
...
@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_m
*
warp_rows
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
rii
,
rjj
=
(
tx
*
local_size_a
+
local_id
)
//
micro_size_k
,
(
tx
*
local_size_a
+
local_id
)
%
(
micro_size_k
)
A_local_buf
[
j
*
local_size_a
+
local_id
]
=
(
A_shared_buf
[
ri
,
rj
,
rii
,
rjj
])
rii
,
rjj
=
(
tx
*
local_size_a
+
local_id
)
//
micro_size_k
,
(
tx
*
local_size_a
+
local_id
)
%
(
micro_size_k
)
A_local_buf
[
j
*
local_size_a
+
local_id
]
=
A_shared_buf
[
ri
,
rj
,
rii
,
rjj
]
else
:
raise
ValueError
(
"Unsupported TransformKind for Input A"
)
...
...
@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
rii
,
rjj
=
(
tx
*
local_size_dequantize
+
local_id
)
//
(
micro_size_k
//
num_elems_per_byte
),
(
tx
*
local_size_dequantize
+
local_id
)
%
(
micro_size_k
//
num_elems_per_byte
)
B_local_buf
[
j
*
local_size_dequantize
+
local_id
]
=
(
B_shared_buf
[
ri
,
rj
,
rii
,
rjj
])
rii
,
rjj
=
(
(
tx
*
local_size_dequantize
+
local_id
)
//
(
micro_size_k
//
num_elems_per_byte
),
(
tx
*
local_size_dequantize
+
local_id
)
%
(
micro_size_k
//
num_elems_per_byte
),
)
B_local_buf
[
j
*
local_size_dequantize
+
local_id
]
=
B_shared_buf
[
ri
,
rj
,
rii
,
rjj
]
else
:
raise
ValueError
(
"Unsupported TransformKind for Input B"
)
...
...
@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
class
INT4TensorCoreIntrinEmitter
(
TensorCoreIntrinEmitter
):
def
mma
(
self
,
A_local_buf
,
B_local_buf
,
C_local_buf
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
...
...
@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
class
INT4TensorCoreIntrinEmitterWithLadderTransform
(
TensorCoreIntrinEmitterWithLadderTransform
):
def
mma
(
self
,
A_local_buf
,
B_local_buf
,
C_local_buf
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
tilelang/intrinsics/mma_sm70_layout.py
View file @
29051439
...
...
@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
def
mma_32x8_to_shared_16x16_layout_fp32
(
thread_id
,
local_id
):
row
=
(
thread_id
%
2
)
+
(
(
local_id
//
2
%
2
)
*
2
)
+
4
*
(
thread_id
//
16
)
+
(
thread_id
%
16
//
4
)
%
2
*
8
col
=
(
thread_id
%
4
//
2
)
*
2
+
(
thread_id
%
16
//
8
)
*
4
+
(
local_id
%
2
)
+
(
local_id
//
4
)
*
8
row
=
(
thread_id
%
2
)
+
((
local_id
//
2
%
2
)
*
2
)
+
4
*
(
thread_id
//
16
)
+
(
thread_id
%
16
//
4
)
%
2
*
8
col
=
(
thread_id
%
4
//
2
)
*
2
+
(
thread_id
%
16
//
8
)
*
4
+
(
local_id
%
2
)
+
(
local_id
//
4
)
*
8
return
row
,
col
...
...
@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
def
mma_load_a_32x4_to_shared_16x4_layout
(
thread_id
,
local_id
):
row
=
(
thread_id
%
4
)
+
(
4
*
((
(
thread_id
//
16
+
thread_id
%
16
//
4
*
2
)
)
%
4
))
row
=
(
thread_id
%
4
)
+
(
4
*
((
thread_id
//
16
+
thread_id
%
16
//
4
*
2
)
%
4
))
col
=
local_id
return
row
,
col
...
...
tilelang/intrinsics/mma_sm70_macro_generator.py
View file @
29051439
...
...
@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter:
def
get_store_index_map
(
self
,
inverse
:
bool
=
False
)
->
IndexMap
:
warp_size
,
local_size_c
=
self
.
WARP_SIZE
,
self
.
local_size_out
index_map
=
IndexMap
.
from_func
(
mma_32x8_to_shared_16x16_layout_fp32
i
f
self
.
accum_dtype
==
"float32"
else
mma_32x8_to_shared_16x16_layout_fp16
,
index_dtype
=
"int32"
)
mma_32x8_to_shared_16x16_layout_fp32
if
self
.
accum_dtype
==
"float32"
else
mma_32x8_to_shared_16x16_layout_fp16
,
i
ndex_dtype
=
"int32"
,
)
if
not
inverse
:
return
index_map
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
return
inverse_index_map
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter:
)
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
warp_row_tiles
=
self
.
warp_row_tiles
warp_rows
=
self
.
warp_rows
chunk
=
self
.
chunk
...
...
@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter:
return
_warp_ldmatrix_a
(
A_local_buf
,
A_region
,
ki
,
thread_binding
,
rk
)
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
warp_col_tiles
=
self
.
warp_col_tiles
warp_cols
=
self
.
warp_cols
chunk
=
self
.
chunk
...
...
@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter:
for
j
in
T
.
vectorized
(
local_size_b
):
if
b_transposed
:
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wi
+
mi
,
B_base1
+
wk
+
mk
]
else
:
mk
,
mi
=
mma_load_layout
(
tx
,
j
)
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_buf
[
B_base0
+
wk
+
mk
,
B_base1
+
wi
+
mi
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_region
,
ki
,
thread_binding
,
rk
)
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
|
None
=
0
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter:
return
_warp_mma
(
A_local_buf
,
B_local_buf
,
C_local_buf
)
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
matrix_is_a
:
bool
=
matrix
==
"A"
matrix_is_b
:
bool
=
matrix
==
"B"
...
...
@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix_is_a
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
elif
matrix_is_b
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_rs_b
(
i
,
j
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_rs_b
(
i
,
j
)
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter:
return
lane_id
,
local_id
base_fragment
=
T
.
Fragment
(
[
micro_size_s
,
micro_size_r
]
if
is_sr_axis_order
else
[
micro_size_r
,
micro_size_s
],
forward_fn
=
forward
,
replicate
=
2
)
[
micro_size_s
,
micro_size_r
]
if
is_sr_axis_order
else
[
micro_size_r
,
micro_size_s
],
forward_fn
=
forward
,
replicate
=
2
)
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
chunk
=
self
.
chunk
...
...
@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter:
replicate
=
block_col_warps
if
matrix_is_a
else
block_row_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
else
:
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
...
...
tilelang/intrinsics/mma_sp_layout.py
View file @
29051439
...
...
@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int:
return
(
thread_id
//
4
)
*
2
+
(
thread_id
%
4
)
%
2
def
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
def
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
logical_id
=
get_logical_id_32bit
(
thread_id
)
row
=
logical_id
//
4
+
local_id
*
8
col
=
logical_id
%
4
return
row
,
col
def
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
def
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
logical_id
=
get_logical_id_32bit
(
thread_id
)
row
=
logical_id
//
2
+
local_id
*
8
col
=
logical_id
%
2
return
row
,
col
def
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
return
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
(
thread_id
,
local_id
)
# same mapping for 16bit and 32bit
def
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
return
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
(
thread_id
,
local_id
)
# same mapping for 16bit and 32bit
def
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
return
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
(
thread_id
,
local_id
)
# same mapping for 16bit and 32bit
def
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
return
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
(
thread_id
,
local_id
)
# same mapping for 16bit and 32bit
def
get_logical_id_8bit
(
thread_id
:
int
)
->
int
:
return
thread_id
def
metadata_8bit_load_32x4_to_shared_16x4_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
def
metadata_8bit_load_32x4_to_shared_16x4_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
logical_id
=
get_logical_id_8bit
(
thread_id
)
row
=
logical_id
//
2
+
local_id
*
8
col
=
(
logical_id
%
4
)
//
2
*
4
+
local_id
return
row
,
col
def
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
def
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
logical_id
=
get_logical_id_8bit
(
thread_id
)
row
=
logical_id
//
2
+
local_id
*
8
col
=
(
logical_id
%
4
)
//
2
*
2
+
local_id
return
row
,
col
def
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
def
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit
(
thread_id
:
int
,
local_id
:
int
)
->
tuple
[
int
,
int
]:
# local_id is always 0
logical_id
=
get_logical_id_8bit
(
thread_id
)
row
=
logical_id
//
4
+
(
logical_id
%
2
)
*
8
...
...
tilelang/intrinsics/mma_sp_macro_generator.py
View file @
29051439
...
...
@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter:
def
_initialize_local_size
(
self
,
m_dim
=
16
,
n_dim
=
16
,
k_dim
=
16
,
warp_size
=
32
):
self
.
local_size_a
=
(
m_dim
*
k_dim
)
//
warp_size
//
self
.
SPARSE_FACTOR
self
.
local_size_e
=
(
m_dim
*
k_dim
)
//
self
.
e_factor
//
warp_size
*
self
.
E_REPLICATE_FACTOR
[
self
.
a_dtype
]
self
.
local_size_e
=
(
m_dim
*
k_dim
)
//
self
.
e_factor
//
warp_size
*
self
.
E_REPLICATE_FACTOR
[
self
.
a_dtype
]
self
.
local_size_b
=
(
n_dim
*
k_dim
)
//
warp_size
self
.
local_size_out
=
(
m_dim
*
n_dim
)
//
warp_size
...
...
@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter:
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
return
inverse_index_map
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
def
extract_thread_binding
(
self
,
thread_id
:
PrimExpr
,
is_m_first
:
bool
|
None
=
None
)
->
tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
...
@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter:
for
i
in
T
.
serial
(
warp_rows
):
# Assign A_shared_buf_elem
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
(
rk
*
warp_k
+
ki
*
micro_size_k
)
//
self
.
SPARSE_FACTOR
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
(
rk
*
warp_k
+
ki
*
micro_size_k
)
//
self
.
SPARSE_FACTOR
A_shared_buf_elem
=
A_shared_buf
[
wk
,
wi
]
if
a_transposed
else
A_shared_buf
[
wi
,
wk
]
if
ldmatrix_available
:
...
...
@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter:
else
:
for
j
in
T
.
serial
(
local_size_a
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
A_local_buf
[
i
*
local_size_a
+
j
]
=
A_shared_buf
[
wk
+
mk
,
wi
+
mi
]
if
a_transposed
else
A_shared_buf
[
wi
+
mi
,
wk
+
mk
]
A_local_buf
[
i
*
local_size_a
+
j
]
=
(
A_shared_buf
[
wk
+
mk
,
wi
+
mi
]
if
a_transposed
else
A_shared_buf
[
wi
+
mi
,
wk
+
mk
]
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter:
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
for
i
in
T
.
serial
(
warp_rows
):
# Assign E_shared_buf_elem
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
(
rk
*
warp_k
+
ki
*
micro_size_k
)
//
self
.
e_factor
wi
,
wk
=
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
(
rk
*
warp_k
+
ki
*
micro_size_k
)
//
self
.
e_factor
for
j
in
T
.
serial
(
local_size_e
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
E_local_buf
[
i
*
local_size_e
+
j
]
=
E_shared_buf
[
wk
+
mk
,
wi
+
mi
]
if
trans
else
E_shared_buf
[
wi
+
mi
,
wk
+
mk
]
E_local_buf
[
i
*
local_size_e
+
j
]
=
E_shared_buf
[
wk
+
mk
,
wi
+
mi
]
if
trans
else
E_shared_buf
[
wi
+
mi
,
wk
+
mk
]
return
_warp_ldmatrix_e
(
E_local_buf
,
E_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter:
b_dtype
=
self
.
b_dtype
b_transposed
=
self
.
b_transposed
thread_binding
=
self
.
get_thread_binding
()
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available
=
not
(
DataType
(
b_dtype
).
bits
!=
16
and
not
b_transposed
)
...
...
@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter:
)
if
ldmatrix_available
:
B_shared_buf_elem
=
B_shared_buf
[
wi
,
wk
]
if
b_transposed
else
B_shared_buf
[
wk
,
wi
]
B_shared_buf_elem
=
B_shared_buf
[
wi
,
wk
]
if
b_transposed
else
B_shared_buf
[
wk
,
wi
]
if
replicate_b
:
T
.
ptx_ldmatrix
(
...
...
@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf
.
data
,
i
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
T
.
address_of
(
B_shared_buf_elem
),
get_ldmatrix_offset_b
(
"B"
,
tx
,
lift
(
local_size_b
)
//
2
,
stride
,
b_dtype
,
b_transposed
),
get_ldmatrix_offset_b
(
"B"
,
tx
,
lift
(
local_size_b
)
//
2
,
stride
,
b_dtype
,
b_transposed
),
)
else
:
T
.
ptx_ldmatrix
(
...
...
@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter:
# must be transposed.
for
j
in
T
.
serial
(
local_size_b
):
mi
,
mk
=
mma_load_layout
(
tx
,
j
)
B_local_buf
[
i
*
local_size_b
+
j
]
=
B_shared_buf
[
wi
+
mi
,
wk
+
mk
]
if
b_transposed
else
B_shared_buf
[
wk
+
mk
,
wi
+
mi
]
B_local_buf
[
i
*
local_size_b
+
j
]
=
(
B_shared_buf
[
wi
+
mi
,
wk
+
mk
]
if
b_transposed
else
B_shared_buf
[
wk
+
mk
,
wi
+
mi
]
)
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
def
mma_sp
(
self
,
A_local_buf
:
Buffer
,
E_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
=
0
):
def
mma_sp
(
self
,
A_local_buf
:
Buffer
,
E_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
PrimExpr
=
0
):
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
...
...
@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter:
accum_dtype
=
self
.
accum_dtype
accum_dtype_abbrv
=
self
.
accum_dtype_abbrv
mma_prefix
=
self
.
mma_prefix
replicate_b
=
(
self
.
n_dim
==
16
)
replicate_b
=
self
.
n_dim
==
16
a_is_fragment
=
is_fragment
(
A_local_buf
)
e_is_fragment
=
is_fragment
(
E_local_buf
)
...
...
@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf
.
data
,
b_local_stride
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
E_local_buf
.
data
,
# metadata
e_local_stride
+
i
*
local_size_e
,
# metadata offset
self
.
SPARSE_SELECTOR
,
# sparse_selector
...
...
@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter:
local_id
=
local_id_o
*
2
+
local_id_i
row
,
col
=
T
.
meta_var
(
mma_store_index_map
(
tx
,
local_id
))
if
C_buf_dims
==
2
:
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[(
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
else
:
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
C_buf
[
warp_m
*
warp_rows
+
i
,
warp_n
*
warp_cols
+
j
,
row
,
col
]
=
C_local_buf
[
i
*
(
warp_cols
*
local_size_out
)
+
j
*
local_size_out
+
local_id
]
@
T
.
macro
def
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
):
...
...
@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter:
C_buf
[
(
pid_m
*
BLOCK_M
+
warp_m
*
warp_rows
+
i
)
*
M_DIM
+
row
,
(
pid_n
*
BLOCK_N
+
warp_n
*
warp_cols
+
j
)
*
n_dim
+
col
,
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
]
=
C_local_buf
[
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
local_id
]
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
))
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
)
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
matrix_is_a
:
bool
=
matrix
==
"A"
matrix_is_b
:
bool
=
matrix
==
"B"
...
...
@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
if
matrix_is_a
:
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
elif
matrix_is_b
:
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
transform_func
=
transform_func_sr_b
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_b
(
j
,
i
)
else
:
raise
ValueError
(
f
"Unsupported matrix
{
matrix
}
"
)
...
...
@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter:
return
local_id
base_fragment
=
T
.
Fragment
(
[
micro_size_s
,
micro_size_r
//
2
if
matrix_is_a
else
micro_size_r
]
if
is_sr_axis_order
[
micro_size_s
,
micro_size_r
//
2
if
matrix_is_a
else
micro_size_r
]
if
is_sr_axis_order
else
[
micro_size_r
//
2
if
matrix_is_a
else
micro_size_r
,
micro_size_s
],
forward_thread_fn
=
forward_thread
,
forward_index_fn
=
forward_index
,
...
...
@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter:
replicate
=
block_col_warps
if
matrix_is_a
else
block_row_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
else
:
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
if
matrix_is_a
:
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
).
replicate
(
replicate
)
elif
matrix_is_b
:
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
block_fragment
=
warp_fragment
.
replicate
(
replicate
).
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
True
)
else
:
raise
ValueError
(
f
"Unsupported matrix type
{
matrix
}
"
)
...
...
tilelang/intrinsics/tcgen05_macro_generator.py
View file @
29051439
...
...
@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first
:
bool
=
False
,
thread_var
:
Var
|
None
=
None
,
):
super
().
__init__
(
a_dtype
,
b_dtype
,
accum_dtype
,
a_transposed
,
b_transposed
,
block_row_warps
,
block_col_warps
,
warp_row_tiles
,
warp_col_tiles
,
chunk
,
reduce_k
,
num_elems_per_byte
,
is_m_first
,
thread_var
)
super
().
__init__
(
a_dtype
,
b_dtype
,
accum_dtype
,
a_transposed
,
b_transposed
,
block_row_warps
,
block_col_warps
,
warp_row_tiles
,
warp_col_tiles
,
chunk
,
reduce_k
,
num_elems_per_byte
,
is_m_first
,
thread_var
,
)
def
_assign_a_shared_layout
(
self
,
layout
:
Layout
):
self
.
a_shared_layout
=
layout
...
...
@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else
:
raise
ValueError
(
f
"Unsupported swizzle mode:
{
layout
}
"
)
def
tcgen05mma
(
self
,
A_buf
:
Buffer
,
B_buf
:
Buffer
,
C_local_buf
:
Buffer
,
mbar
,
clear_accum
:
PrimExpr
=
False
):
def
tcgen05mma
(
self
,
A_buf
:
Buffer
,
B_buf
:
Buffer
,
C_local_buf
:
Buffer
,
mbar
,
clear_accum
:
PrimExpr
=
False
):
if
is_tensor_memory
(
A_buf
):
return
self
.
tcgen05mma_rs
(
A_buf
,
B_buf
,
C_local_buf
,
clear_accum
)
...
...
@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bits
=
DataType
(
self
.
a_dtype
).
bits
elems_in_bytes
=
elems_in_bits
//
8
a_swizzle_atom_elems
=
a_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
(
)
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
()
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
accum_dtype_in_bits
=
DataType
(
accum_dtype
).
bits
meta
=
self
.
get_tcgen5_mma_meta
(
m_dim
,
n_dim
,
k_dim
)
if
len
(
meta
)
!=
5
:
raise
ValueError
(
f
"Unsupported TCGEN5MMA configuration for desc generation: M=
{
m_dim
}
, N=
{
n_dim
}
, "
f
"K=
{
k_dim
}
, A dtype=
{
self
.
a_dtype
}
, accum dtype=
{
self
.
accum_dtype
}
"
)
f
"K=
{
k_dim
}
, A dtype=
{
self
.
a_dtype
}
, accum dtype=
{
self
.
accum_dtype
}
"
)
atom_m
,
atom_n
,
atom_k
,
enable_ws
,
enable_2cta
=
(
int
(
x
)
for
x
in
meta
)
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
m_dim
*
elems_in_bytes
)
a_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
8
*
elems_in_bytes
)
a_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
m_dim
*
elems_in_bytes
)
a_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
8
*
elems_in_bytes
)
if
not
a_swizzle_mode
.
is_none
():
# swizzle mode doesn't require LBO/SBO to be 1
...
...
@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else
:
a_stride_byte_offset
=
8
*
elems_in_bytes
*
a_swizzle_atom_elems
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
if
not
b_swizzle_mode
.
is_none
():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
for
ki
in
T
.
unroll
(
0
,
(
k_dim
//
micro_size_k
)):
scale_out
=
T
.
Select
(
ki
!=
0
,
1
,
T
.
Select
(
clear_accum
,
0
,
1
))
A_elem_offset
=
(
ki
%
ak_atom_size
)
*
micro_size_k
+
i
*
atom_m
*
a_swizzle_atom_elems
+
(
ki
//
ak_atom_size
)
*
m_dim
*
a_swizzle_atom_elems
if
a_is_k_major
else
i
*
atom_m
*
k_dim
+
ki
*
a_swizzle_atom_elems
*
micro_size_k
(
ki
%
ak_atom_size
)
*
micro_size_k
+
i
*
atom_m
*
a_swizzle_atom_elems
+
(
ki
//
ak_atom_size
)
*
m_dim
*
a_swizzle_atom_elems
if
a_is_k_major
else
i
*
atom_m
*
k_dim
+
ki
*
a_swizzle_atom_elems
*
micro_size_k
)
B_elem_offset
=
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
+
j
*
atom_n
*
b_swizzle_atom_elems
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
j
*
atom_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
))
B_elem_offset
=
(
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
+
j
*
atom_n
*
b_swizzle_atom_elems
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
j
*
atom_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
)
)
)
A_byte_offset
=
A_elem_offset
*
elems_in_bytes
B_byte_offset
=
B_elem_offset
*
elems_in_bytes
C_offset
=
(
i
*
n_dim
+
j
*
tmem_col_step
)
*
accum_dtype_in_bits
//
32
# 32 bits per tmem bank
C_offset
=
(
i
*
n_dim
+
j
*
tmem_col_step
)
*
accum_dtype_in_bits
//
32
# 32 bits per tmem bank
T
.
ptx_tcgen05_mma_ss
(
a_dtype_abbrv
,
...
...
@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
assert
is_tensor_memory
(
tmem_buf
),
"tmem_buf must reside in tensor memory (shared.tmem)"
if
len
(
tmem_buf
.
shape
)
!=
2
:
raise
ValueError
(
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
raise
ValueError
(
f
"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape
{
tmem_buf
.
shape
}
"
)
m
=
int
(
tmem_buf
.
shape
[
0
])
n
=
int
(
tmem_buf
.
shape
[
1
])
...
...
@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
meta
=
self
.
get_tcgen5_mma_meta
(
m
,
n
,
k
)
if
len
(
meta
)
!=
5
:
raise
ValueError
(
f
"Unsupported TCGEN5MMA configuration: M=
{
m
}
, N=
{
n
}
, K=
{
k
}
, "
f
"A dtype=
{
self
.
a_dtype
}
, accum dtype=
{
self
.
accum_dtype
}
"
)
raise
ValueError
(
f
"Unsupported TCGEN5MMA configuration: M=
{
m
}
, N=
{
n
}
, K=
{
k
}
, A dtype=
{
self
.
a_dtype
}
, accum dtype=
{
self
.
accum_dtype
}
"
)
atom_m
,
atom_n
,
_
,
_
,
_
=
(
int
(
x
)
for
x
in
meta
)
if
m
%
atom_m
!=
0
or
n
%
atom_n
!=
0
:
raise
ValueError
(
f
"Invalid TCGEN5MMA store layout for shape (
{
m
}
,
{
n
}
) with atoms (
{
atom_m
}
,
{
atom_n
}
)"
)
raise
ValueError
(
f
"Invalid TCGEN5MMA store layout for shape (
{
m
}
,
{
n
}
) with atoms (
{
atom_m
}
,
{
atom_n
}
)"
)
def
forward
(
i
:
PrimExpr
,
j
:
PrimExpr
):
atom_idx
=
(
i
//
atom_m
)
+
(
j
//
atom_n
)
*
(
m
//
atom_m
)
...
...
@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return
Layout
([
m
,
n
],
forward
)
def
get_tcgen5_mma_meta
(
self
,
m
:
int
,
n
:
int
,
k
:
int
):
return
_ffi_api
.
get_tcgen5_mma_meta
(
int
(
m
),
int
(
n
),
int
(
k
),
DataType
(
self
.
a_dtype
),
DataType
(
self
.
accum_dtype
))
return
_ffi_api
.
get_tcgen5_mma_meta
(
int
(
m
),
int
(
n
),
int
(
k
),
DataType
(
self
.
a_dtype
),
DataType
(
self
.
accum_dtype
))
def
get_tcgen5_instr_desc
(
self
,
atom_m
:
int
,
atom_n
:
int
,
atom_k
:
int
,
a_is_k_major
:
bool
,
b_is_k_major
:
bool
,
scale_in_a
:
int
,
scale_in_b
:
int
)
->
PrimExpr
:
def
get_tcgen5_instr_desc
(
self
,
atom_m
:
int
,
atom_n
:
int
,
atom_k
:
int
,
a_is_k_major
:
bool
,
b_is_k_major
:
bool
,
scale_in_a
:
int
,
scale_in_b
:
int
)
->
PrimExpr
:
desc
=
_ffi_api
.
get_tcgen5_instr_desc
(
atom_m
,
atom_n
,
...
...
tilelang/intrinsics/utils.py
View file @
29051439
...
...
@@ -10,7 +10,7 @@ from .mma_layout import (
mma_store_32x8_to_shared_16x16_layout
,
mma_store_32x2_to_shared_8x8_layout_fp64
,
)
from
.mfma_layout
import
(
thread_id_shared_access_64x4_to_16x16_layout_C_n_m
)
from
.mfma_layout
import
thread_id_shared_access_64x4_to_16x16_layout_C_n_m
from
.mma_layout
import
get_swizzle_layout
# noqa: F401
from
.mma_layout
import
make_mma_swizzle_layout
# noqa: F401
...
...
tilelang/intrinsics/wgmma_macro_generator.py
View file @
29051439
...
...
@@ -15,9 +15,11 @@ from tilelang.layout import (
make_linear_layout
,
)
from
tvm.runtime
import
convert
from
tilelang.intrinsics.mma_layout
import
(
shared_16x8_to_mma_32x4_layout_sr_a
,
shared_16x16_to_mma_32x8_layout_sr_a
,
shared_16x32_to_mma_32x16_layout_sr_a
)
from
tilelang.intrinsics.mma_layout
import
(
shared_16x8_to_mma_32x4_layout_sr_a
,
shared_16x16_to_mma_32x8_layout_sr_a
,
shared_16x32_to_mma_32x16_layout_sr_a
,
)
lift
=
convert
...
...
@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first
:
bool
|
None
=
False
,
thread_var
:
Var
|
None
=
None
,
):
super
().
__init__
(
a_dtype
,
b_dtype
,
accum_dtype
,
a_transposed
,
b_transposed
,
block_row_warps
,
block_col_warps
,
warp_row_tiles
,
warp_col_tiles
,
chunk
,
reduce_k
,
num_elems_per_byte
,
is_m_first
,
thread_var
)
super
().
__init__
(
a_dtype
,
b_dtype
,
accum_dtype
,
a_transposed
,
b_transposed
,
block_row_warps
,
block_col_warps
,
warp_row_tiles
,
warp_col_tiles
,
chunk
,
reduce_k
,
num_elems_per_byte
,
is_m_first
,
thread_var
,
)
self
.
_initialize_wgmma_prefix
(
self
.
n_dim
)
def
_assign_a_shared_layout
(
self
,
layout
:
Layout
):
...
...
@@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def
_initialize_wgmma_prefix
(
self
,
n_dim
:
int
=
16
):
inst_m
,
inst_n
=
64
,
gcd
(
self
.
warp_col_tiles
,
256
)
assert
inst_n
%
8
==
0
,
(
f
"inst_n must be a multiple of 8, got
{
inst_n
}
"
f
"(block_col_warps=
{
self
.
block_col_warps
}
, warp_col_tiles=
{
self
.
warp_col_tiles
}
)"
)
f
"inst_n must be a multiple of 8, got
{
inst_n
}
(block_col_warps=
{
self
.
block_col_warps
}
, warp_col_tiles=
{
self
.
warp_col_tiles
}
)
"
)
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert
8
<=
inst_n
<=
256
,
(
f
"inst_n must be within [8, 256], got
{
inst_n
}
"
f
"(block_col_warps=
{
self
.
block_col_warps
}
, warp_col_tiles=
{
self
.
warp_col_tiles
}
)"
)
f
"inst_n must be within [8, 256], got
{
inst_n
}
(block_col_warps=
{
self
.
block_col_warps
}
, warp_col_tiles=
{
self
.
warp_col_tiles
}
)
"
)
# 256 bits per instruction
inst_k
=
256
//
DataType
(
self
.
a_dtype
).
bits
self
.
wgmma_inst_m
=
inst_m
...
...
@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else
:
raise
ValueError
(
f
"Unsupported swizzle mode:
{
layout
}
"
)
def
wgmma
(
self
,
A_region
:
BufferRegion
,
B_region
:
BufferRegion
,
C_region
:
BufferRegion
,
clear_accum
:
PrimExpr
=
False
,
wg_wait
:
int
=
0
):
def
wgmma
(
self
,
A_region
:
BufferRegion
,
B_region
:
BufferRegion
,
C_region
:
BufferRegion
,
clear_accum
:
PrimExpr
=
False
,
wg_wait
:
int
=
0
):
if
is_fragment
(
A_region
):
return
self
.
wgmma_rs
(
A_region
,
B_region
,
C_region
,
clear_accum
,
wg_wait
)
...
...
@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bytes
=
elems_in_bits
//
8
a_swizzle_atom_elems
=
a_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
(
)
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
()
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
accum_bits
=
DataType
(
accum_dtype
).
bits
accum_regs
=
((
m_dim
//
64
)
*
warp_cols
*
local_size_out
*
accum_bits
+
31
)
//
32
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
m_dim
*
elems_in_bytes
)
a_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
8
*
elems_in_bytes
)
a_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
m_dim
*
elems_in_bytes
)
a_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
a_is_k_major
else
(
8
*
8
*
elems_in_bytes
)
if
not
a_swizzle_mode
.
is_none
():
# swizzle mode doesn't require LBO/SBO to be 1
...
...
@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
if
a_m_axis_atoms
<=
1
:
a_leading_byte_offset
=
0
else
:
a_leading_byte_offset
=
8
*
a_swizzle_mode
.
swizzle_atom_size
()
*
(
a_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
)
a_leading_byte_offset
=
8
*
a_swizzle_mode
.
swizzle_atom_size
()
*
(
a_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
)
if
a_m_axis_atoms
<=
1
:
a_stride_byte_offset
=
8
*
elems_in_bytes
*
m_dim
else
:
a_stride_byte_offset
=
8
*
elems_in_bytes
*
a_swizzle_atom_elems
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
if
not
b_swizzle_mode
.
is_none
():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
desc_a
=
T
.
alloc_wgmma_desc
()
desc_b
=
T
.
alloc_wgmma_desc
()
T
.
initialize_wgmma_descriptor
(
desc_a
,
A_ptr
,
a_swizzle_mode
,
int
(
a_leading_byte_offset
>>
4
),
int
(
a_stride_byte_offset
>>
4
))
T
.
initialize_wgmma_descriptor
(
desc_b
,
B_ptr
,
b_swizzle_mode
,
int
(
b_leading_byte_offset
>>
4
),
int
(
b_stride_byte_offset
>>
4
))
T
.
initialize_wgmma_descriptor
(
desc_a
,
A_ptr
,
a_swizzle_mode
,
int
(
a_leading_byte_offset
>>
4
),
int
(
a_stride_byte_offset
>>
4
))
T
.
initialize_wgmma_descriptor
(
desc_b
,
B_ptr
,
b_swizzle_mode
,
int
(
b_leading_byte_offset
>>
4
),
int
(
b_stride_byte_offset
>>
4
))
T
.
warpgroup_fence_operand
(
C_buf
,
num_regs
=
accum_regs
)
T
.
warpgroup_arrive
()
...
...
@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
warp_i
=
(
warp_m
//
4
)
*
num_inst_m
+
i
warp_j
=
warp_n
*
num_inst_n
+
j
A_offset
=
(
ki
%
ak_atom_size
)
*
micro_size_k
+
warp_i
*
64
*
a_swizzle_atom_elems
+
(
ki
//
ak_atom_size
)
*
m_dim
*
a_swizzle_atom_elems
if
a_is_k_major
else
warp_i
*
64
*
k_dim
+
ki
*
a_swizzle_atom_elems
*
micro_size_k
B_offset
=
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
b_swizzle_atom_elems
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
))
(
ki
%
ak_atom_size
)
*
micro_size_k
+
warp_i
*
64
*
a_swizzle_atom_elems
+
(
ki
//
ak_atom_size
)
*
m_dim
*
a_swizzle_atom_elems
if
a_is_k_major
else
warp_i
*
64
*
k_dim
+
ki
*
a_swizzle_atom_elems
*
micro_size_k
)
B_offset
=
(
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
b_swizzle_atom_elems
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
)
)
)
C_offset
=
i
*
warp_cols
*
local_size_out
+
j
*
warp_cols
*
local_size_out
//
num_inst_n
# 4 warps as an unit
T
.
ptx_wgmma_ss
(
accum_dtype
,
wgmma_prefix
,
a_is_k_major
,
b_is_k_major
,
a_dtype_abbrv
,
b_dtype_abbrv
,
accum_dtype_abbrv
,
desc_a
.
data
,
(
A_offset
*
elems_in_bytes
)
>>
4
,
desc_b
.
data
,
(
B_offset
*
elems_in_bytes
)
>>
4
,
C_buf
.
data
,
C_offset
,
scale_out
,
scale_in_a
,
scale_in_b
)
T
.
ptx_wgmma_ss
(
accum_dtype
,
wgmma_prefix
,
a_is_k_major
,
b_is_k_major
,
a_dtype_abbrv
,
b_dtype_abbrv
,
accum_dtype_abbrv
,
desc_a
.
data
,
(
A_offset
*
elems_in_bytes
)
>>
4
,
desc_b
.
data
,
(
B_offset
*
elems_in_bytes
)
>>
4
,
C_buf
.
data
,
C_offset
,
scale_out
,
scale_in_a
,
scale_in_b
,
)
T
.
warpgroup_commit_batch
()
if
wg_wait
>=
0
:
...
...
@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return
_warp_mma
(
A_ptr
,
B_ptr
,
C_buf
)
def
wgmma_rs
(
self
,
A_region
:
BufferRegion
,
B_region
:
BufferRegion
,
C_region
:
BufferRegion
,
clear_accum
:
PrimExpr
=
False
,
wg_wait
:
int
=
0
):
def
wgmma_rs
(
self
,
A_region
:
BufferRegion
,
B_region
:
BufferRegion
,
C_region
:
BufferRegion
,
clear_accum
:
PrimExpr
=
False
,
wg_wait
:
int
=
0
):
local_size_a
=
self
.
local_size_a
local_size_out
=
self
.
local_size_out
a_dtype_abbrv
=
self
.
a_dtype_abbrv
...
...
@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
b_is_k_major
=
self
.
b_transposed
b_swizzle_mode
=
self
.
_determinate_swizzle_mode
(
B_region
,
self
.
b_shared_layout
)
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
(
)
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
b_swizzle_atom_elems
=
n_dim
if
b_swizzle_mode
.
is_none
()
else
b_swizzle_mode
.
swizzle_byte_size
()
//
elems_in_bytes
b_leading_byte_offset
=
(
8
*
8
*
elems_in_bytes
)
if
b_is_k_major
else
(
8
*
n_dim
*
elems_in_bytes
)
b_stride_byte_offset
=
(
8
*
k_dim
*
elems_in_bytes
)
if
b_is_k_major
else
(
0
if
n_dim
==
8
else
(
8
*
8
*
elems_in_bytes
))
if
not
b_swizzle_mode
.
is_none
():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...
...
@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
tx
,
warp_n
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
desc_b
=
T
.
alloc_wgmma_desc
()
T
.
initialize_wgmma_descriptor
(
desc_b
,
B_ptr
,
b_swizzle_mode
,
int
(
b_leading_byte_offset
>>
4
),
int
(
b_stride_byte_offset
>>
4
))
T
.
initialize_wgmma_descriptor
(
desc_b
,
B_ptr
,
b_swizzle_mode
,
int
(
b_leading_byte_offset
>>
4
),
int
(
b_stride_byte_offset
>>
4
))
T
.
warpgroup_fence_operand
(
A_buf
,
num_regs
=
a_regs
)
T
.
warpgroup_fence_operand
(
C_buf
,
num_regs
=
accum_regs
)
T
.
warpgroup_arrive
()
...
...
@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
A_offset
=
ki
*
warp_rows
*
local_size_a
+
i
*
local_size_a
B_offset
=
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
warp_j
*
wgmma_inst_n
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
))
(
ki
//
bk_atom_size
)
*
n_dim
*
b_swizzle_atom_elems
+
warp_j
*
wgmma_inst_n
*
b_swizzle_atom_elems
+
(
ki
%
bk_atom_size
)
*
micro_size_k
if
b_is_k_major
else
(
ki
*
b_swizzle_atom_elems
*
micro_size_k
+
warp_j
*
wgmma_inst_n
*
(
k_dim
if
n_dim
//
b_swizzle_atom_elems
>
1
else
1
)
)
)
C_offset
=
i
*
warp_cols
*
local_size_out
+
j
*
warp_cols
*
local_size_out
//
num_inst_n
# 4 warps as an unit
T
.
ptx_wgmma_rs
(
accum_dtype
,
...
...
@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
If `local_buf` is not detected to be a fragment buffer.
"""
from
tilelang.utils
import
is_fragment
assert
matrix
in
[
"A"
],
"matrix should be A for WGMMA"
dtype
=
self
.
a_dtype
dtype_bits
=
DataType
(
dtype
).
bits
...
...
@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func
:
Callable
=
None
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
transform_func
=
transform_func_sr_a
if
is_sr_axis_order
else
lambda
i
,
j
:
transform_func_sr_a
(
j
,
i
)
assert
is_fragment
(
local_buf
),
f
"local_buf must be a fragment, but got
{
local_buf
.
scope
()
}
"
...
...
@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
replicate
=
block_col_warps
if
is_sr_axis_order
:
warp_fragment
=
base_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
False
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
warp_fragment
=
base_fragment
.
repeat
([
block_s
,
1
],
repeat_on_thread
=
True
,
lower_dim_first
=
False
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
warp_s
,
warp_r
],
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
else
:
# rs condition, transposed_a matrix
warp_fragment
=
base_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
False
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
warp_fragment
=
base_fragment
.
repeat
([
1
,
block_s
],
repeat_on_thread
=
True
,
lower_dim_first
=
False
).
replicate
(
replicate
)
block_fragment
=
warp_fragment
.
repeat
([
warp_r
,
warp_s
],
repeat_on_thread
=
False
,
lower_dim_first
=
True
)
return
block_fragment
...
...
tilelang/ir.py
View file @
29051439
...
...
@@ -7,23 +7,19 @@ from tilelang import _ffi_api
@
tvm_ffi
.
register_object
(
"tl.Fill"
)
class
Fill
(
Node
,
Scriptable
):
...
class
Fill
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.AtomicAdd"
)
class
AtomicAdd
(
Node
,
Scriptable
):
...
class
AtomicAdd
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.Copy"
)
class
Copy
(
Node
,
Scriptable
):
...
class
Copy
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.Conv2DIm2Col"
)
class
Conv2DIm2ColOp
(
Node
,
Scriptable
):
...
class
Conv2DIm2ColOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.GemmWarpPolicy"
)
...
...
@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable):
m_warp
:
int
n_warp
:
int
def
compute_warp_partition
(
self
,
M
:
int
,
N
:
int
,
block_size
:
int
,
target
:
Target
,
is_wgmma
:
bool
):
_ffi_api
.
GemmWarpPolicyComputeWarpPartition
(
self
,
int
(
M
),
int
(
N
),
int
(
block_size
),
target
,
is_wgmma
)
def
compute_warp_partition
(
self
,
M
:
int
,
N
:
int
,
block_size
:
int
,
target
:
Target
,
is_wgmma
:
bool
):
_ffi_api
.
GemmWarpPolicyComputeWarpPartition
(
self
,
int
(
M
),
int
(
N
),
int
(
block_size
),
target
,
is_wgmma
)
return
self
.
m_warp
,
self
.
n_warp
...
...
@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable):
m_warp
:
int
n_warp
:
int
def
compute_warp_partition
(
self
,
M
:
int
,
N
:
int
,
block_size
:
int
,
target
:
Target
,
is_wgmma
:
bool
,
bits
:
int
):
_ffi_api
.
GemmSPWarpPolicyComputeWarpPartition
(
self
,
int
(
M
),
int
(
N
),
int
(
block_size
),
target
,
is_wgmma
,
bits
)
def
compute_warp_partition
(
self
,
M
:
int
,
N
:
int
,
block_size
:
int
,
target
:
Target
,
is_wgmma
:
bool
,
bits
:
int
):
_ffi_api
.
GemmSPWarpPolicyComputeWarpPartition
(
self
,
int
(
M
),
int
(
N
),
int
(
block_size
),
target
,
is_wgmma
,
bits
)
return
self
.
m_warp
,
self
.
n_warp
@
tvm_ffi
.
register_object
(
"tl.Gemm"
)
class
Gemm
(
Node
,
Scriptable
):
...
class
Gemm
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.GemmSP"
)
class
GemmSP
(
Node
,
Scriptable
):
...
class
GemmSP
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.FinalizeReducerOp"
)
class
FinalizeReducerOp
(
Node
,
Scriptable
):
...
class
FinalizeReducerOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.ParallelOp"
)
class
ParallelOp
(
Node
,
Scriptable
):
...
class
ParallelOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.ReduceOp"
)
class
ReduceOp
(
Node
,
Scriptable
):
...
class
ReduceOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.CumSumOp"
)
class
CumSumOp
(
Node
,
Scriptable
):
...
class
CumSumOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.RegionOp"
)
class
RegionOp
(
Node
,
Scriptable
):
...
class
RegionOp
(
Node
,
Scriptable
):
...
@
tvm_ffi
.
register_object
(
"tl.ReduceType"
)
class
ReduceType
(
Node
,
Scriptable
):
...
class
ReduceType
(
Node
,
Scriptable
):
...
Prev
1
…
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment