OpenDAS / tilelang · Commits

Commit bbbf4207 (Unverified)
Authored Nov 14, 2025 by guchaoyang; committed via GitHub on Nov 14, 2025

    Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f
Changes: 286
Showing 20 changed files with 1958 additions and 357 deletions (+1958 / -357)
Changed files:
    tilelang/autotuner/tuner.py                        +83   -143
    tilelang/carver/roller/policy/default.py           +1    -1
    tilelang/carver/roller/shape_inference/tir.py      +1    -1
    tilelang/contrib/cc.py                             +1    -1
    tilelang/contrib/dlpack.py                         +3    -3
    tilelang/contrib/hipcc.py                          +2    -2
    tilelang/contrib/nvcc.py                           +160  -9
    tilelang/contrib/rocm.py                           +8    -5
    tilelang/engine/callback.py                        +3    -3
    tilelang/engine/lower.py                           +6    -5
    tilelang/engine/phase.py                           +2    -2
    tilelang/env.py                                    +15   -4
    tilelang/intrinsics/mfma_macro_generator.py        +314  -66
    tilelang/intrinsics/mma_layout.py                  +6    -0
    tilelang/intrinsics/mma_macro_generator.py         +166  -31
    tilelang/intrinsics/mma_sm70_layout.py             +51   -0
    tilelang/intrinsics/mma_sm70_macro_generator.py    +528  -0
    tilelang/intrinsics/tcgen05_macro_generator.py     +441  -0
    tilelang/intrinsics/utils.py                       +5    -0
    tilelang/intrinsics/wgmma_macro_generator.py       +162  -81
tilelang/autotuner/tuner.py  (view file @ bbbf4207)

@@ -4,17 +4,24 @@ This module provides functionality for auto-tuning tilelang programs, including
 and performance optimization through configuration search.
 """
 from __future__ import annotations
 from dataclasses import dataclass
 import tilelang
 from tilelang import tvm as tvm
 from tilelang.jit import JITImpl
 from tilelang.jit.kernel import JITKernel
 from tvm.tir import PrimFunc, Var
 from tvm.target import Target
 import inspect
 from functools import partial
-from typing import (Callable, Literal, Any, overload)
-from tqdm import tqdm
+from typing import (Callable, Generic, Literal, Any, TypeVar)
+
+# Python 3.9 compatibility for ParamSpec
+try:
+    from typing import ParamSpec
+except ImportError:
+    # Python < 3.10
+    from typing_extensions import ParamSpec
+from tqdm.auto import tqdm
 import logging
 import functools
 import concurrent.futures
 import torch
 import os

@@ -30,7 +37,6 @@ from tilelang import env
 from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
 from tilelang.autotuner.capture import get_autotune_inputs
 from tilelang.utils.target import determine_target
-from tilelang.jit.param import _P, _RProg
 from tilelang import __version__

@@ -524,12 +530,12 @@ class AutoTuner:
                     # latency, ref_latency = target_fn(jit_kernel)
                     latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
                 except TimeoutException:
-                    logger.info(
+                    logger.warning(
                         f"A timeout occurred while testing config {config}, checkout autotuner.log for more details")
                     continue
                 except Exception:
-                    logger.info(
+                    logger.warning(
                         f"An error occurred while testing config {config}, checkout autotuner.log for more details")
                     logger.debug(f"Error: {traceback.format_exc()}")

@@ -585,9 +591,13 @@ class AutoTuner:
         return self.run()


-class _AutoTunerImplementation:
-    # Overload __init__ to help type checkers understand the effect of return_program
-    # The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
+_P = ParamSpec('_P')
+_T = TypeVar('_T')
+
+
+@dataclass
+class AutoTuneImpl(Generic[_P, _T]):
+    jit_impl: JITImpl
     warmup: int = 25
     rep: int = 100

@@ -603,91 +613,12 @@ class _AutoTunerImplementation:
     manual_check_prog: Callable = None
     cache_input_tensors: bool = False

-    def __init__(self,
-                 configs: dict | Callable,
-                 warmup: int = 25,
-                 rep: int = 100,
-                 timeout: int = 100,
-                 supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
-                 ref_prog: Callable = None,
-                 supply_prog: Callable = None,
-                 rtol: float = 1e-2,
-                 atol: float = 1e-2,
-                 max_mismatched_ratio: float = 0.01,
-                 skip_check: bool = False,
-                 manual_check_prog: Callable = None,
-                 cache_input_tensors: bool = False) -> None:
-        """Initialize the AutoTunerImplementation.
-
-        Args:
-            configs: Configuration space to explore during auto-tuning.
-            warmup: Number of warmup iterations before timing.
-            rep: Number of repetitions for timing measurements.
-            timeout: Maximum time (in seconds) allowed for each configuration.
-            supply_type: Strategy for generating input tensors (random/zeros/etc)
-            ref_prog: Reference implementation for validation
-            supply_prog: Custom function to provide input tensors
-            rtol: Relative tolerance for numerical validation
-            atol: Absolute tolerance for numerical validation
-            max_mismatched_ratio: Allowed percentage of mismatched values
-            skip_check: Bypass validation against reference implementation
-            manual_check_prog: Custom validation function
-            cache_input_tensors: Reuse input tensors across trials
-        """
-        # Configuration and benchmarking parameters
-        self.configs = configs  # Search space of tuning configurations
-        self.warmup = warmup  # Warmup iterations for stable measurements
-        self.rep = rep  # Measurement repetitions for statistics
-        self.timeout = timeout  # Per-configuration timeout threshold
-
-        # Tensor handling and validation setup
-        self.supply_type = supply_type  # Input tensor generation strategy
-        self.ref_prog = ref_prog  # Ground truth implementation
-        self.supply_prog = supply_prog  # Custom input data provider
-        self.rtol = rtol  # Relative error tolerance
-        self.atol = atol  # Absolute error tolerance
-        self.max_mismatched_ratio = max_mismatched_ratio  # Allowed mismatch
-
-        # Validation control flags
-        self.skip_check = skip_check  # Bypass accuracy verification
-        self.manual_check_prog = manual_check_prog  # Custom validation
-        self.cache_input_tensors = cache_input_tensors  # Reuse inputs
-
-        # Cache for storing tuned kernel implementations
-        self._tuner_cache: dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel
-
-    # This tells the type checker what the *wrapper* function will return.
-    # this is for linting, please do not remove it.
-    @overload
-    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
-        ...
-
-    @overload
-    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
-        ...
-
-    # Actual implementation of __call__
-    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
-        warmup = self.warmup
-        rep = self.rep
-        timeout = self.timeout
-
-        @functools.wraps(fn)
-        def wrapper(*args, **kwargs):
-            key_args_tuple = args
-            key_kwargs_tuple = tuple(sorted(kwargs.items()))
-            key = (key_args_tuple, key_kwargs_tuple)
-            if key not in self._tuner_cache:
-
-                def jit_compile(**config_arg):
-                    return fn(*args, **kwargs, __tune_params=config_arg)
-
-                compile_arguments = fn(__return_compile_arguments=True)
-                autotuner = AutoTuner(fn, configs=self.configs).set_profile_args(
+    def __post_init__(self):
+        self._tuner_cache = {}
+
+    def get_tunner(self):
+        autotuner = AutoTuner(self.jit_impl.func, configs=self.configs).set_profile_args(
             supply_type=self.supply_type,
             ref_prog=self.ref_prog,
             supply_prog=self.supply_prog,

@@ -698,30 +629,35 @@ class _AutoTunerImplementation:
             manual_check_prog=self.manual_check_prog,
             cache_input_tensors=self.cache_input_tensors,
         ).set_compile_args(
-            out_idx=compile_arguments['out_idx'],
-            execution_backend=compile_arguments['execution_backend'],
-            target=compile_arguments['target'],
-            target_host=compile_arguments['target_host'],
-            verbose=compile_arguments['verbose'],
-            pass_configs=compile_arguments['pass_configs'],
+            out_idx=self.jit_impl.out_idx,
+            execution_backend=self.jit_impl.execution_backend,
+            target=self.jit_impl.target,
+            target_host=self.jit_impl.target_host,
+            verbose=self.jit_impl.verbose,
+            pass_configs=self.jit_impl.pass_configs,
         )
+        autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
+        return autotuner
-                autotuner.jit_compile = jit_compile
-                autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
+
+    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel:
+        key_args_tuple = args
+        key_kwargs_tuple = tuple(sorted(kwargs.items()))
+        key = (key_args_tuple, key_kwargs_tuple)
+        if key not in self._tuner_cache:
-                autotuner.run = partial(autotuner.run, warmup, rep, timeout)
+
+            def jit_compile(**config_arg):
+                return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
+
+            autotuner = self.get_tunner()
+            autotuner.jit_compile = jit_compile
+            autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
             artifact = autotuner.run()
             self._tuner_cache[key] = artifact.kernel
         return self._tuner_cache[key]
-
-        return wrapper


 def autotune(  # This is the new public interface
-        func: Callable[_P, _RProg] | PrimFunc | None = None,
+        func: Callable[_P, _T] | PrimFunc | None = None,
         *,  # Indicates subsequent arguments are keyword-only
         configs: dict | Callable,
         # profile arguments

@@ -795,10 +731,13 @@ def autotune(  # This is the new public interface
     elif isinstance(func, PrimFunc):
         raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
     else:
-        # Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
-        # Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
-        # This instance is a decorator that will be applied to the function later.
-        configured_decorator = _AutoTunerImplementation(
+        def decorator(impl):
+            assert isinstance(impl, JITImpl), \
+                "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
+            return AutoTuneImpl(
+                jit_impl=impl,
                 configs=configs,
                 warmup=warmup,
                 rep=rep,

@@ -813,4 +752,5 @@ def autotune(  # This is the new public interface
                 manual_check_prog=manual_check_prog,
                 cache_input_tensors=cache_input_tensors,
             )
-    return configured_decorator
+
+        return decorator
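Note (not part of the commit): with this change, @autotune only accepts objects produced by @tilelang.jit (the new decorator asserts isinstance(impl, JITImpl)), and AutoTuneImpl.__call__ memoizes the tuned kernel per call signature. The sketch below is standalone Python that only illustrates that caching contract; TinyTunerCache and its lambda are made-up stand-ins, not tilelang APIs.

    # Standalone sketch of the (args, sorted kwargs) memoization used by AutoTuneImpl.__call__.
    class TinyTunerCache:
        def __init__(self, tune_fn):
            self.tune_fn = tune_fn      # the expensive "tuning" step
            self._cache = {}            # (args, kwargs) -> tuned result

        def __call__(self, *args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            if key not in self._cache:
                self._cache[key] = self.tune_fn(*args, **kwargs)
            return self._cache[key]

    tuned = TinyTunerCache(lambda M, N, K, dtype="float16": f"kernel<{M}x{N}x{K},{dtype}>")
    assert tuned(1024, 1024, 1024) is tuned(1024, 1024, 1024)  # second call hits the cache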
tilelang/carver/roller/policy/default.py  (view file @ bbbf4207)

@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import math
 from queue import PriorityQueue
-from typing import Iterable
+from collections.abc import Iterable
 import numpy as np
 import tvm
tilelang/carver/roller/shape_inference/tir.py  (view file @ bbbf4207)

 from __future__ import annotations
-from typing import Mapping
+from collections.abc import Mapping
 from tvm.tir.schedule.schedule import BlockRV
 from tvm.ir import structural_equal
 from tvm import arith, tir
tilelang/contrib/cc.py  (view file @ bbbf4207)

@@ -64,7 +64,7 @@ def get_cc():
     return None


-@functools.lru_cache(maxsize=None)
+@functools.cache
 def get_cplus_compiler():
     """Return the path to the default C/C++ compiler.
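Note (not part of the commit): functools.cache is simply the unbounded form of lru_cache, available since Python 3.9, so this is a behavior-preserving simplification. A minimal sketch with a hypothetical lookup function:

    import functools

    @functools.cache  # equivalent to functools.lru_cache(maxsize=None) on Python >= 3.9
    def find_compiler():
        # hypothetical stand-in for the real compiler lookup in cc.py
        return "/usr/bin/g++"

    find_compiler()  # computed once
    find_compiler()  # served from the cache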
tilelang/contrib/dlpack.py  (view file @ bbbf4207)

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Wrapping functions to bridge frameworks with DLPack support to TVM"""
-from tvm.runtime import ndarray
+from tvm import runtime


 def convert_func(tvm_func, tensor_type, to_dlpack_func):

@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
                     torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
                     torch.float8_e5m2fnuz
             }:
-                return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
+                return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
                     arg.shape, dtype=float8_dtype_map[arg.dtype])
-            return ndarray.from_dlpack(to_dlpack_func(arg))
+            return runtime.from_dlpack(to_dlpack_func(arg))
         return arg

     def _wrapper(*args):
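Note (not part of the commit): the wrapper above ships fp8 tensors across DLPack as int8 bytes and reinterprets them on the other side. The sketch below shows that round trip in plain PyTorch only (no TVM), and assumes a PyTorch build new enough to provide the float8 dtypes.

    import torch
    from torch.utils import dlpack as torch_dlpack

    x = torch.randn(4, 4).to(torch.float8_e4m3fn)

    capsule = torch_dlpack.to_dlpack(x.view(torch.int8))   # export the raw bytes as int8
    y_int8 = torch_dlpack.from_dlpack(capsule)             # any DLPack consumer would work here
    y = y_int8.view(torch.float8_e4m3fn)                   # reinterpret the bits as fp8 again

    assert torch.equal(x.view(torch.int8), y.view(torch.int8))  # bit-exact round trip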
tilelang/contrib/hipcc.py  (view file @ bbbf4207)

@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs
 import subprocess
-import tvm.ffi
+import tvm_ffi
 from tvm.contrib import utils
 from tvm.base import py_str

@@ -97,7 +97,7 @@ def compile_hip(code,
     return data


-@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
     """use hipcc to generate fatbin code for better optimization"""
     hsaco = compile_hip(code, target_format="hsaco")
tilelang/contrib/nvcc.py  (view file @ bbbf4207)

@@ -7,9 +7,12 @@ from __future__ import annotations
 import os
 import subprocess
 import warnings
-from tilelang.env import CUDA_HOME
-import tvm.ffi
+import contextlib
+from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
+import shutil
+import tempfile
+import tvm_ffi
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm.base import py_str

@@ -125,6 +128,154 @@ def compile_cuda(code,
     return data


+def default_compile_options(compile_flags: list[str] | None = None) -> list[str]:
+    """
+    Build a set of default NVCC compile options for TileLang generated sources.
+
+    Includes C++ standard and common include paths (TileLang templates, CUTLASS,
+    CUDA include). Merges user-provided compile flags if given.
+
+    Parameters
+    ----------
+    compile_flags : Optional[List[str]]
+        Additional flags to include. Items are split on whitespace.
+
+    Returns
+    -------
+    List[str]
+        A list of flags suitable for NVCC's command line.
+    """
+    options: list[str] = ["-std=c++17"]
+    try:
+        if TILELANG_TEMPLATE_PATH:
+            options.append(f"-I{TILELANG_TEMPLATE_PATH}")
+    except Exception:
+        pass
+    try:
+        if CUTLASS_INCLUDE_DIR:
+            options.append(f"-I{CUTLASS_INCLUDE_DIR}")
+    except Exception:
+        pass
+    try:
+        if CUDA_HOME:
+            options.append(f"-I{os.path.join(CUDA_HOME, 'include')}")
+    except Exception:
+        pass
+    # Preserve user flags exactly, including repeated tokens required by NVCC
+    # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
+    if compile_flags:
+        import shlex
+        for flag in compile_flags:
+            # Split each string like a shell would, preserving quoted args
+            tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
+            options.extend(tokens)
+    return options
+
+
+def get_ptx_from_source(code: str,
+                        compile_flags: list[str] | None = None,
+                        verbose: bool = False) -> str:
+    """
+    Compile CUDA C++ source to PTX using NVCC and return as text.
+
+    Parameters
+    ----------
+    code : str
+        CUDA C++ kernel source code.
+    compile_flags : Optional[List[str]]
+        Additional flags merged with defaults.
+    verbose : bool
+        Print NVCC output when True.
+
+    Returns
+    -------
+    str
+        PTX text.
+    """
+    opts = default_compile_options(compile_flags)
+    ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose)
+    try:
+        return ptx_bytes.decode("utf-8")
+    except Exception:
+        return str(ptx_bytes)
+
+
+def _find_tool(name: str) -> str | None:
+    """Find a CUDA binary in PATH or under CUDA_HOME/bin."""
+    path = shutil.which(name)
+    if path:
+        return path
+    if CUDA_HOME:
+        candidate = os.path.join(CUDA_HOME, "bin", name)
+        if os.path.exists(candidate):
+            return candidate
+    return None
+
+
+def get_sass_from_source(code: str,
+                         compile_flags: list[str] | None = None,
+                         verbose: bool = False) -> str:
+    """
+    Compile CUDA C++ source to CUBIN and disassemble to SASS.
+
+    Uses nvdisasm if available; otherwise falls back to cuobjdump.
+
+    Parameters
+    ----------
+    code : str
+        CUDA C++ kernel source code.
+    compile_flags : Optional[List[str]]
+        Additional flags merged with defaults.
+    verbose : bool
+        Print tool outputs when True.
+
+    Returns
+    -------
+    str
+        SASS text.
+    """
+    opts = default_compile_options(compile_flags)
+    cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose)
+
+    # Write to a temp .cubin file
+    with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp:
+        tmp.write(cubin_bytes)
+        cubin_path = tmp.name
+
+    # Try disassembly tools (prefer nvdisasm, fallback cuobjdump)
+    cand_nvdisasm = _find_tool("nvdisasm")
+    cand_cuobjdump = _find_tool("cuobjdump")
+    if not cand_nvdisasm and not cand_cuobjdump:
+        raise RuntimeError(
+            "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.")
+
+    last_err: str | None = None
+    try:
+        # Attempt nvdisasm first
+        tools_to_try = []
+        if cand_nvdisasm:
+            tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path]))
+        if cand_cuobjdump:
+            tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path]))
+
+        for tool_name, cmd in tools_to_try:
+            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+            out, _ = proc.communicate()
+            text = py_str(out)
+            if verbose:
+                print(f"[{tool_name}] output:\n{text}")
+            if proc.returncode == 0 and text.strip():
+                return text
+            last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
+
+        # If we reach here, all attempts failed
+        raise RuntimeError(f"SASS disassembly failed. Tried tools: "
+                           f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
+    finally:
+        with contextlib.suppress(Exception):
+            os.remove(cubin_path)
+
+
 def find_cuda_path():
     """Utility function to find cuda path

@@ -182,14 +333,14 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")


-@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
 def tilelang_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
     """use nvcc to generate fatbin code for better optimization"""
     ptx = compile_cuda(code, target_format="fatbin")
     return ptx


-@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
 def find_libdevice_path(arch):
     """Utility function to find libdevice

@@ -254,7 +405,7 @@ def callback_libdevice_path(arch):
         return ""


-@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True)
 def get_target_compute_version(target=None):
     """Utility function to get compute capability of compilation target.

@@ -400,7 +551,7 @@ def have_cudagraph():
         return False


-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True)
 def have_bf16(compute_version):
     """Either bf16 support is provided in the compute capability or not

@@ -413,7 +564,7 @@ def have_bf16(compute_version):
     return major >= 8


-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True)
 def have_fp8(compute_version):
     """Whether fp8 support is provided in the specified compute capability or not

@@ -430,7 +581,7 @@ def have_fp8(compute_version):
     return any(conditions)


-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True)
 def have_tma(target):
     """Whether TMA support is provided in the specified compute capability or not
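Note (not part of the commit): a minimal usage sketch for the two inspection helpers added above. It assumes a CUDA toolkit is installed and that the module is importable as tilelang.contrib.nvcc; the kernel source is a throwaway example, and how compile_cuda picks the target architecture is left to its defaults here.

    from tilelang.contrib import nvcc

    kernel_src = r"""
    extern "C" __global__ void add_one(float* x, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] += 1.0f;
    }
    """

    ptx = nvcc.get_ptx_from_source(kernel_src)      # PTX as text
    print(ptx.splitlines()[0])                      # e.g. the PTX version header

    sass = nvcc.get_sass_from_source(kernel_src)    # CUBIN disassembled via nvdisasm/cuobjdump
    print(sass[:200])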
tilelang/contrib/rocm.py  (view file @ bbbf4207)

@@ -21,7 +21,7 @@ import subprocess
 import os
 from os.path import join, exists
-import tvm.ffi
+import tvm_ffi
 from tvm.base import py_str
 import tvm.runtime
 import tvm.target

@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
         raise RuntimeError(msg)


-@tvm.ffi.register_func("tvm_callback_rocm_link", override=True)
+@tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True)
 def callback_rocm_link(obj_bin):
     """Links object file generated from LLVM to HSA Code Object

@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
     return cobj_bin


-@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
+@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True)
 def callback_rocm_bitcode_path(rocdl_dir=None):
     """Utility function to find ROCm device library bitcodes

@@ -226,8 +226,11 @@ def have_matrixcore(compute_version=None):
     return False


-@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
-def get_rocm_arch(rocm_path="/opt/dtk"):
+@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
+def get_rocm_arch(rocm_path="/opt/rocm"):
+    # @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
+    # def get_rocm_arch(rocm_path="/opt/dtk"):
     """Utility function to get the AMD GPU architecture

     Parameters
tilelang/engine/callback.py  (view file @ bbbf4207)

 from __future__ import annotations
 from typing import Callable
-from tvm import register_func
+import tvm_ffi
 from tvm.target import Target

@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool =
         and returns the processed code (str).
         override: Whether to override existing registered function. Defaults to True.
     """
-    register_func("tilelang_callback_cuda_postproc", f=func, override=override)
+    tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override)


 def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):

@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
         and returns the processed code (str).
         override: Whether to override existing registered function. Defaults to True.
     """
-    register_func("tilelang_callback_hip_postproc", f=func, override=override)
+    tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)


 def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
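Note (not part of the commit): register_cuda_postproc keeps its public signature; only the underlying registration call changes. A minimal sketch of installing a post-processing hook, assuming tilelang and tvm are importable; the add_banner function is a made-up example of a string-to-string rewrite of the generated CUDA source.

    from tilelang.engine.callback import register_cuda_postproc
    from tvm.target import Target

    def add_banner(code: str, target: Target) -> str:
        # Receives the generated CUDA source and the target; must return the new source.
        return f"// postprocessed for {target.kind.name}\n" + code

    register_cuda_postproc(add_banner, override=True)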
tilelang/engine/lower.py  (view file @ bbbf4207)

@@ -7,6 +7,7 @@ from typing import Callable
 import tilelang.transform
 from tilelang import tvm as tvm
 from tvm import tir
+import tvm_ffi
 from tvm.ir import CallingConv
 from tvm.target import Target
 from tilelang.contrib import hipcc, nvcc

@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
     return lambda func: not get_device_call(is_device_c)(func)


-@tvm.register_func("tilelang_callback_cuda_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
 def tilelang_callback_cuda_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     if "TL_TEMPLATE_PATH" in os.environ:

@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target):
     return ptx


-@tvm.register_func("tilelang_callback_hip_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     tl_template_path = osp.abspath(osp.join(project_root, "src"))

@@ -182,7 +183,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
     elif target.kind.name == "llvm":
         device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
     elif target.kind.name == "webgpu":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target)
     elif target.kind.name == "metal":
         device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
     else:

@@ -241,6 +242,6 @@ def lower(
         host_mod = host_codegen(host_mod, target_host)
         host_mod.import_module(codegen_mod)
         return CompiledArtifact(
-            host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod)
+            host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)

-    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
+    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
tilelang/engine/phase.py  (view file @ bbbf4207)

@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.LetInline()(mod)
     # Add wrapper for single buf store
     mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
+    # Normalize negative indices to canonical non-negative form
+    mod = tilelang.transform.LegalizeNegativeIndex()(mod)
     # Inject assumes to speedup tvm prover
     mod = tilelang.transform.InjectAssumes()(mod)
     # Simplify the IR expressions

@@ -118,8 +120,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # TODO(lei): return to tir pass when kSymbolicBound simplification
     # is merged into tvm.
     mod = tilelang.transform.Simplify()(mod)
-    # Try to vectorize loop with dynamic shape
-    mod = tilelang.transform.LoopVectorizeDynamic()(mod)
     return mod
tilelang/env.py  (view file @ bbbf4207)

@@ -236,6 +236,10 @@ class Environment:
         "1")  # print kernel name on compile
     TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # clear cache automatically if set

+    # Kernel selection options
+    # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
+    TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")
+
     # Auto-tuning settings
     TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9")  # percent of CPUs used

@@ -274,6 +278,14 @@ class Environment:
     def is_print_on_compilation_enabled(self) -> bool:
         return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")

+    def use_gemm_v1(self) -> bool:
+        """Return True if GEMM v1 should be used based on env.
+
+        Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of
+        {"1", "true", "yes", "on"} (case-insensitive).
+        """
+        return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")
+

 # Instantiate as a global configuration object
 env = Environment()

@@ -297,12 +309,11 @@ def prepend_pythonpath(path):
 if env.TVM_IMPORT_PYTHON_PATH is not None:
     prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
 else:
-    tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm")
+    tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
     assert os.path.exists(tvm_path), tvm_path
     if tvm_path not in sys.path:
-        tvm_python_binding = os.path.join(tvm_path, 'python')
-        prepend_pythonpath(tvm_python_binding)
-        env.TVM_IMPORT_PYTHON_PATH = tvm_python_binding
+        prepend_pythonpath(tvm_path)
+        env.TVM_IMPORT_PYTHON_PATH = tvm_path

 if os.environ.get("TVM_LIBRARY_PATH") is None:
     os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
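Note (not part of the commit): a minimal sketch of toggling the new GEMM-version switch. Whether the variable is read lazily or at import time is an assumption here, so it is set before importing tilelang to be safe; any of "1"/"true"/"yes"/"on" (case-insensitive) selects GEMM v1, the default "0" keeps GEMM v2.

    import os
    os.environ["TILELANG_USE_GEMM_V1"] = "on"   # force GEMM v1

    from tilelang import env                    # import after setting the variable

    if env.use_gemm_v1():
        print("GEMM v1 selected")
    else:
        print("GEMM v2 (default)")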
tilelang/intrinsics/mfma_macro_generator.py  (view file @ bbbf4207)

@@ -2,10 +2,32 @@ from __future__ import annotations
 from tilelang import tvm as tvm
 import tilelang.language as T
 from tvm import DataType
-from tvm.tir import PrimExpr
+from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
 from tvm.runtime import convert
 from .utils import (mfma_store_index_map,)
+from typing import Literal, Callable
+from tilelang.utils import is_fragment
+from tilelang.utils.language import to_buffer_region
+from .mfma_layout import (
+    shared_16x4_to_local_64x1_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
+    shared_16x16_to_local_64x4_layout_A,
+    shared_16x16_to_local_64x4_layout_B,
+    shared_16x32_to_local_64x8_layout_A,
+    shared_16x32_to_local_64x8_layout_B,
+    shared_16x64_to_local_64x16_layout_A,
+    shared_16x64_to_local_64x16_layout_B,
+    thread_id_shared_access_64x1_to_16x4_layout_A,
+    thread_id_shared_access_64x1_to_4x16_layout_B,
+    thread_id_shared_access_64x4_to_16x16_layout_A,
+    thread_id_shared_access_64x4_to_16x16_layout_B,
+    thread_id_shared_access_64x8_to_16x32_layout_A,
+    thread_id_shared_access_64x8_to_16x32_layout_B,
+    thread_id_shared_access_64x16_to_16x64_layout_A,
+    thread_id_shared_access_64x16_to_16x64_layout_B,
+)

 lift = convert

@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter:
         k_pack: int | None = None,
         is_m_first: bool | None = False,
         b_preshuffle: bool | None = False,
+        thread_var: Var | None = None,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype

@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter:
         self.reduce_k = reduce_k
         self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
         self.num_elems_per_byte = num_elems_per_byte
+        self.thread_var = thread_var

     def _initialize_k_dim(self, a_dtype="float16"):
         if isinstance(a_dtype, str):

@@ -115,6 +139,7 @@ class MatrixCoreIntrinEmitter:
         }[out_dtype]

         in_dtype_abbrv = {
+            "bfloat16": "bf16",
             "float16": "f16",
             "float32": "f32",
             "int8": "i8",

@@ -126,6 +151,9 @@ class MatrixCoreIntrinEmitter:
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
         elif in_dtype_abbrv == "i8":
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
+        elif in_dtype_abbrv == "bf16":
+            # HIP intrinsic uses ...x{K}bf16_1k without an underscore before bf16
+            self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}bf16_1k"
         else:
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"

@@ -147,24 +175,6 @@ class MatrixCoreIntrinEmitter:
             self.b_preshuffle = b_preshuffle

     def get_ldmatrix_index_map(self, is_b=False):
-        from .mfma_layout import (
-            shared_16x4_to_local_64x1_layout_A,
-            shared_4x16_to_local_64x1_layout_B,
-            shared_16x16_to_local_64x4_layout_A,
-            shared_16x16_to_local_64x4_layout_B,
-            shared_16x32_to_local_64x8_layout_A,
-            shared_16x32_to_local_64x8_layout_B,
-            shared_16x64_to_local_64x16_layout_A,
-            shared_16x64_to_local_64x16_layout_B,
-            thread_id_shared_access_64x1_to_16x4_layout_A,
-            thread_id_shared_access_64x1_to_4x16_layout_B,
-            thread_id_shared_access_64x4_to_16x16_layout_A,
-            thread_id_shared_access_64x4_to_16x16_layout_B,
-            thread_id_shared_access_64x8_to_16x32_layout_A,
-            thread_id_shared_access_64x8_to_16x32_layout_B,
-            thread_id_shared_access_64x16_to_16x64_layout_A,
-            thread_id_shared_access_64x16_to_16x64_layout_B,
-        )
         k_dim = self.k_dim * self.k_pack
         transposed = self.a_transposed if not is_b else self.b_transposed

@@ -200,6 +210,22 @@ class MatrixCoreIntrinEmitter:

         return index_map, reverse_index_map

+    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
+        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
+        index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
+        if not inverse:
+            return index_map
+        inverse_index_map = index_map.inverse([warp_size, local_size_c])
+        return inverse_index_map
+
+    def get_thread_binding(self):
+        if self.thread_var is None:
+            current_frame = T.KernelLaunchFrame.Current()
+            assert current_frame is not None, "Must be called in a T.Kernel Frame"
+            return current_frame.get_thread_binding()
+        else:
+            return self.thread_var
+
     def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:

@@ -229,7 +255,7 @@ class MatrixCoreIntrinEmitter:
                 (WARP_SIZE * block_row_warps)) % block_col_warps,
             return lane_id, warp_n, warp_m

-    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
+    def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
         chunk = self.chunk

@@ -238,10 +264,15 @@ class MatrixCoreIntrinEmitter:
         local_size_a = self.local_size_a
         k_pack = self.k_pack
         is_transposed = self.a_transposed
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)

+        # legalize shared buffer to region
+        A_region = to_buffer_region(A_shared_buf)
+        A_buf = A_region.buffer
+        A_base0 = A_region.region[-2].min
+        A_base1 = A_region.region[-1].min
+
         @T.macro
         def _warp_ldmatrix_a(
                 A_local_buf,

@@ -257,20 +288,20 @@ class MatrixCoreIntrinEmitter:
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (rk * chunk + ki * (k_pack * micro_size_k),
                                 warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_row_tiles + i * micro_size_x,
                                 rk * chunk + ki * (k_pack * micro_size_k))
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]

         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)

-    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
+    def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
         chunk = self.chunk

@@ -279,10 +310,15 @@ class MatrixCoreIntrinEmitter:
         local_size_b = self.local_size_b
         k_pack = self.k_pack
         is_transposed = self.b_transposed
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)

+        # legalize shared buffer to region
+        B_region = to_buffer_region(B_shared_buf)
+        B_buf = B_region.buffer
+        B_base0 = B_region.region[-2].min
+        B_base1 = B_region.region[-1].min
+
         @T.macro
         def _warp_ldmatrix_b(
                 B_local_buf,

@@ -300,8 +336,8 @@ class MatrixCoreIntrinEmitter:
                             warp_n * warp_col_tiles + j * micro_size_y,
                             rk * chunk + ki * (k_pack * micro_size_k),
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
             else:
                 for j in T.serial(warp_cols):

@@ -311,12 +347,16 @@ class MatrixCoreIntrinEmitter:
                             rk * chunk + ki * (k_pack * micro_size_k),
                             warp_n * warp_col_tiles + j * micro_size_y,
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]

         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

-    def mfma(self, A_local_buf, B_local_buf, C_local_buf):
+    def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_a = self.local_size_a

@@ -329,8 +369,13 @@ class MatrixCoreIntrinEmitter:
         compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
         compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"

+        a_is_fragment = is_fragment(A_local_buf)
+        b_is_fragment = is_fragment(B_local_buf)
+        a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
+        b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
+
         @T.macro
-        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
+        def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
             for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                 T.tvm_mfma(
                     mfma_suffix,

@@ -340,15 +385,15 @@ class MatrixCoreIntrinEmitter:
                     compute_b_dtype,
                     compute_out_dtype,
                     B_local_buf.data,
-                    ((j * k_pack + kp) * local_size_b) // local_size_b,
+                    (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b,
                     A_local_buf.data,
-                    ((i * k_pack + kp) * local_size_a) // local_size_a,
+                    (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a,
                     C_local_buf.data,
                     (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                     dtype=compute_out_dtype,
                 )

-        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
+        return _warp_mfma(A_local_buf, B_local_buf, C_local_buf)

     def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps

@@ -356,8 +401,7 @@ class MatrixCoreIntrinEmitter:
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_out = self.local_size_out
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         is_global = pid_m is not None and pid_n is not None
         BLOCK_M = block_row_warps * warp_rows
         BLOCK_N = block_col_warps * warp_cols

@@ -366,7 +410,7 @@ class MatrixCoreIntrinEmitter:
         assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"

         # STS
-        # MMA Store must be in simulated instead of TVM Intrins
+        # MFMA Store must be in simulated instead of TVM Intrins
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro

@@ -400,6 +444,217 @@ class MatrixCoreIntrinEmitter:
                     thread_binding) if is_global else _warp_stmatrix_shared(
                         C_local_buf, C_buf, thread_binding)

+    def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
+        """
+        Create a layout function for storing MFMA results into a fragment buffer.
+
+        Parameters
+        ----------
+        local_buf : tir.Buffer
+            The local buffer representing a fragment of a matrix.
+
+        Returns
+        -------
+        T.Fragment
+            A fragment object that describes how threads and indices
+            in `local_buf` are laid out.
+
+        Raises
+        ------
+        AssertionError
+            If `local_buf` is not detected to be a fragment buffer.
+        """
+        from tilelang.utils import is_fragment
+
+        assert matrix in ["A", "B"], "matrix should be either A or B"
+        matrix_is_a: bool = matrix == "A"
+        matrix_is_b: bool = matrix == "B"
+        transposed = self.a_transposed if matrix_is_a else self.b_transposed
+
+        # s represents spatial axis
+        # r represents reduction axis
+        # sr represents the two dims are spatial + reduction
+        # rs represents the two dims are reduction + spatial
+        # sr also can represent a non-transposed basic layout
+        # then rs also can represent a transposed basic layout
+        transform_func_sr_a: Callable = None
+        transform_func_sr_b: Callable = None
+        k_dim = self.k_dim * self.k_pack
+        if k_dim == 4:
+            transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
+            transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+        elif k_dim == 16:
+            transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
+            transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+        elif k_dim == 32:
+            transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
+            transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+        elif k_dim == 64:
+            transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
+            transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+        else:
+            raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
+
+        is_sr_conditions = [False]
+        is_sr_conditions.append(matrix_is_a and not transposed)
+        is_sr_conditions.append(matrix_is_b and transposed)
+        is_sr_axis_order = any(is_sr_conditions)
+
+        transform_func: Callable = None
+        if matrix_is_a:
+            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
+        elif matrix_is_b:
+            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
+        else:
+            raise ValueError(f"Unsupported matrix {matrix}")
+
+        assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
+
+        if matrix_is_a:
+            micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
+        else:
+            micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
+
+        block_row_warps, block_col_warps = (
+            self.block_row_warps,
+            self.block_col_warps,
+        )
+        inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+
+        def forward_thread(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            """
+            lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
+            return lane_id
+
+        def forward_index(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            """
+            _, local_id = inverse_mfma_load_layout.map_indices([i, j])
+            return local_id
+
+        base_fragment = T.Fragment(
+            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
+            forward_thread_fn=forward_thread,
+            forward_index_fn=forward_index,
+        )
+
+        warp_rows, warp_cols = self.warp_rows, self.warp_cols
+        chunk = self.chunk
+        warp_s = warp_rows if matrix_is_a else warp_cols
+        warp_r = chunk // micro_size_r
+        block_s = block_row_warps if matrix_is_a else block_col_warps
+        replicate = block_col_warps if matrix_is_a else block_row_warps
+
+        if is_sr_axis_order:
+            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
+            if matrix_is_a:
+                block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+            elif matrix_is_b:
+                block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
+            else:
+                raise ValueError(f"Unsupported matrix type {matrix}")
+        else:
+            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
+            if matrix_is_a:
+                block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
+            elif matrix_is_b:
+                block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
+            else:
+                raise ValueError(f"Unsupported matrix type {matrix}")
+
+        return block_fragment
+
+    def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment:
+        """
+        Create a layout function for storing MFMA results into a fragment buffer.
+
+        Parameters
+        ----------
+        local_buf : tir.Buffer
+            The local buffer representing a fragment of a matrix.
+
+        Returns
+        -------
+        T.Fragment
+            A fragment object that describes how threads and indices
+            in `local_buf` are laid out.
+
+        Raises
+        ------
+        AssertionError
+            If `local_buf` is not detected to be a fragment buffer.
+        """
+        from tilelang.utils import is_fragment
+
+        shape = local_buf.shape
+        inverse_mfma_store_layout = self.get_store_index_map(inverse=True)
+        assert is_fragment(local_buf), "local_buf must be a fragment"
+        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
+        local_size_out = self.local_size_out
+        block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
+        warp_rows, warp_cols = self.warp_rows, self.warp_cols
+        warp_size = self.WARP_SIZE
+        is_m_first = self.is_m_first
+
+        def forward_thread(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a thread index according to `inverse_mfma_store_layout`.
+            """
+            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
+            # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
+            block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
+            # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
+            mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
+            lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
+            if is_m_first:
+                thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
+            else:
+                thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
+            return thread_id
+
+        def forward_index(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a local index in a single thread according
+            to `inverse_mfma_store_layout`.
+            """
+            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
+            # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
+            warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
+            # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
+            mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
+            _, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
+            return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
+
+        return T.Fragment(
+            shape,
+            forward_thread_fn=forward_thread,
+            forward_index_fn=forward_index,
+        )
+

 class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):

@@ -421,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         is_m_first: bool | None = False,
         a_preshuffle: bool | None = False,
         b_preshuffle: bool | None = False,
+        thread_var: Var | None = None,
     ):
-        self.a_dtype = a_dtype
-        self.b_dtype = b_dtype
-        self.accum_dtype = accum_dtype
-        self.a_transposed = a_transposed
-        self.b_transposed = b_transposed
-        # Hint Information
-        self.block_row_warps = block_row_warps
-        self.block_col_warps = block_col_warps
-        self.warp_row_tiles = warp_row_tiles
-        self.warp_col_tiles = warp_col_tiles
-        self.chunk = chunk
-        self._initialize_k_dim(a_dtype)
-        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
-        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
-        self._initialize_mfma_prefix(self.k_dim)
-        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
-        self._initialize_k_pack(k_pack)
-        self._initialize_is_m_first(is_m_first)
+        super().__init__(
+            a_dtype=a_dtype,
+            b_dtype=b_dtype,
+            accum_dtype=accum_dtype,
+            a_transposed=a_transposed,
+            b_transposed=b_transposed,
+            block_row_warps=block_row_warps,
+            block_col_warps=block_col_warps,
+            warp_row_tiles=warp_row_tiles,
+            warp_col_tiles=warp_col_tiles,
+            chunk=chunk,
+            reduce_k=reduce_k,
+            num_elems_per_byte=num_elems_per_byte,
+            k_pack=k_pack,
+            is_m_first=is_m_first,
+            thread_var=thread_var,
+        )
         self._initialize_preshuffle(a_preshuffle, b_preshuffle)
-        self.warp_rows = warp_row_tiles // self.micro_size_x
-        self.warp_cols = warp_col_tiles // self.micro_size_y
-        self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
-        self.num_elems_per_byte = num_elems_per_byte

     def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
         if a_preshuffle is not None:
             self.a_preshuffle = a_preshuffle
tilelang/intrinsics/mma_layout.py  (view file @ bbbf4207)

@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
     return row, col


+def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
+    row = thread_id // 4
+    col = (thread_id % 4) * 2 + local_id
+    return row, col
+
+
 # sr represents spatial + reduction layout
 # the first axis is spatial while the second axis is reduction
 # mma.sync matrix A layout, if wanna trans, please apply map_indices
tilelang/intrinsics/mma_macro_generator.py
View file @
bbbf4207
...
...
@@ -3,13 +3,14 @@ import tilelang.language as T
from
typing
import
Literal
,
Callable
from
tilelang.common
import
TransformKind
from
tvm
import
DataType
from
tvm.tir
import
PrimExpr
,
IndexMap
,
Buffer
,
Var
from
tvm.tir
import
PrimExpr
,
IndexMap
,
Buffer
,
Var
,
BufferRegion
from
tilelang
import
tvm
as
tvm
from
tvm.runtime
import
convert
from
.utils
import
(
mma_store_index_map
,
get_ldmatrix_offset
,
)
from
tilelang.utils
import
is_fragment
from
tilelang.utils
import
is_fragment
,
to_buffer_region
from
tilelang.intrinsics.mma_layout
import
(
shared_16x8_to_mma_32x4_layout_sr_a
,
shared_16x8_to_mma_32x4_layout_sr_b
,
...
...
@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter:
"float16"
:
"fp16"
,
"bfloat16"
:
"bf16"
,
"float32"
:
"fp32"
,
"float64"
:
"fp64"
,
"int8"
:
"int8"
,
"int32"
:
"int32"
,
"float8_e4m3"
:
"e4m3"
,
...
...
@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter:
self
.
warp_col_tiles
=
warp_col_tiles
self
.
chunk
=
chunk
self
.
_initialize_k_dim
(
a_dtype
)
# For FP64, MMA shape is m8n8k4; adjust instance dims early
if
DataType
(
a_dtype
).
bits
==
64
:
# Override default M/N dims for fp64 MMA
self
.
M_DIM
=
8
# n_dim will be set to 8 in _initialize_micro_size via k_dim==4
self
.
_initialize_abbrev
(
a_dtype
,
b_dtype
,
accum_dtype
)
self
.
_initialize_micro_size
(
self
.
M_DIM
,
self
.
k_dim
)
self
.
_initialize_local_size
(
self
.
M_DIM
,
self
.
n_dim
,
self
.
k_dim
,
self
.
WARP_SIZE
)
...
...
@@ -105,12 +112,21 @@ class TensorCoreIntrinEmitter:
self
.
local_size_out
=
(
m_dim
*
n_dim
)
//
warp_size
def
_initialize_abbrev
(
self
,
a_dtype
,
b_dtype
,
accum_dtype
):
self
.
a_dtype_abbrv
=
self
.
dtype_abbrv
[
a_dtype
]
self
.
b_dtype_abbrv
=
self
.
dtype_abbrv
[
b_dtype
]
self
.
accum_dtype_abbrv
=
self
.
dtype_abbrv
[
accum_dtype
]
self
.
a_dtype_abbrv
=
self
.
_get_dtype_abbrv
(
a_dtype
)
self
.
b_dtype_abbrv
=
self
.
_get_dtype_abbrv
(
b_dtype
)
self
.
accum_dtype_abbrv
=
self
.
_get_dtype_abbrv
(
accum_dtype
)
def
_get_dtype_abbrv
(
self
,
dtype
:
str
)
->
str
:
try
:
return
self
.
dtype_abbrv
[
dtype
]
except
KeyError
as
err
:
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
from
err
def
_initialize_mma_prefix
(
self
,
k_dim
:
int
=
16
):
if
k_dim
==
8
:
if
k_dim
==
4
:
# fp64
self
.
mma_prefix
=
"m8n8k4"
elif
k_dim
==
8
:
# typically used for tfloat32
self
.
mma_prefix
=
"m16n8k8"
elif
k_dim
==
16
:
...
...
@@ -125,6 +141,15 @@ class TensorCoreIntrinEmitter:
def
_initialize_micro_size
(
self
,
m_dim
:
int
=
16
,
k_dim
:
int
=
16
):
warp_row_tiles
=
self
.
warp_row_tiles
warp_col_tiles
=
self
.
warp_col_tiles
# For fp64 (k_dim==4), micro tile is 8x8, otherwise keep 16x{8|16}
if
k_dim
==
4
:
# fp64 path: m_dim must be 8, n_dim 8
assert
m_dim
==
8
,
f
"For fp64 MMA, m_dim must be 8, got
{
m_dim
}
"
self
.
n_dim
=
8
self
.
micro_size_y
=
8
self
.
warp_rows
=
warp_row_tiles
//
m_dim
self
.
warp_cols
=
warp_col_tiles
//
8
else
:
assert
warp_row_tiles
>=
16
,
f
"warp_row_tiles must be greater than 16, got
{
warp_row_tiles
}
"
assert
warp_row_tiles
%
16
==
0
,
f
"warp_row_tiles must be divisible by 16, got
{
warp_row_tiles
}
"
assert
warp_col_tiles
>=
8
,
f
"warp_col_tiles must be greater than 8, got
{
warp_col_tiles
}
"
...
...
@@ -158,7 +183,11 @@ class TensorCoreIntrinEmitter:
return
self
.
thread_var
def
get_store_index_map
(
self
,
inverse
:
bool
=
False
)
->
IndexMap
:
from
.utils
import
mma_store_index_map
,
mma_store_index_map_fp64
warp_size
,
local_size_c
=
self
.
WARP_SIZE
,
self
.
local_size_out
if
DataType
(
self
.
accum_dtype
).
bits
==
64
:
index_map
=
IndexMap
.
from_func
(
mma_store_index_map_fp64
,
index_dtype
=
"int32"
)
else
:
index_map
=
IndexMap
.
from_func
(
mma_store_index_map
,
index_dtype
=
"int32"
)
if
not
inverse
:
return
index_map
...
...
@@ -199,9 +228,47 @@ class TensorCoreIntrinEmitter:
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
:
PrimExpr
,
rk
:
PrimExpr
|
None
=
0
):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if
DataType
(
self
.
a_dtype
).
bits
==
64
:
warp_row_tiles
=
self
.
warp_row_tiles
warp_rows
=
self
.
warp_rows
chunk
=
self
.
chunk
micro_size_x
=
self
.
micro_size_x
# 8
micro_size_k
=
self
.
micro_size_k
# 4
local_size_a
=
self
.
local_size_a
# 1
a_transposed
=
self
.
a_transposed
thread_binding
=
self
.
get_thread_binding
()
# legalize shared buffer to region
A_region
=
to_buffer_region
(
A_shared_buf
)
A_buf
=
A_region
.
buffer
A_base0
=
A_region
.
region
[
-
2
].
min
A_base1
=
A_region
.
region
[
-
1
].
min
@
T
.
macro
def
_warp_ld_a_fp64
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
=
0
,
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
                    thread_binding)
                for i in T.serial(warp_rows):
                    wi = warp_m * warp_row_tiles + i * micro_size_x
                    wk = rk * chunk + ki * micro_size_k
                    mi = tx // micro_size_k
                    mk = tx % micro_size_k
                    if a_transposed:
                        A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
                    else:
                        A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]

            return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)

        warp_row_tiles = self.warp_row_tiles
        warp_rows = self.warp_rows
        chunk = self.chunk
...
@@ -226,6 +293,13 @@ class TensorCoreIntrinEmitter:
        thread_binding = self.get_thread_binding()

        # legalize shared buffer to region
        A_region = to_buffer_region(A_shared_buf)
        A_buf = A_region.buffer
        A_base0 = A_region.region[-2].min
        A_base1 = A_region.region[-1].min
        A_stride_last = A_buf.shape[-1]

        @T.macro
        def _warp_ldmatrix_a(
            A_local_buf,
...
@@ -234,14 +308,16 @@ class TensorCoreIntrinEmitter:
            thread_binding,
            rk=0,
        ):
-           stride = A_shared_buf.shape[-1]
+           stride = A_stride_last
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            trans = self.a_transposed

            for i in T.serial(warp_rows):
                # Assign A_shared_buf_elem
                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
-               A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
+               A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk]
                if ldmatrix_available:
                    T.ptx_ldmatrix(
...
@@ -257,15 +333,59 @@ class TensorCoreIntrinEmitter:
                else:
                    for j in T.serial(local_size_a):
                        mi, mk = mma_load_layout(tx, j)
-                       A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi]
+                       if a_transposed:
+                           A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
+                       else:
+                           A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]

-       return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
+       return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

    def ldmatrix_b(self,
                   B_local_buf: Buffer,
-                  B_shared_buf: Buffer,
+                  B_shared_buf: Buffer | BufferRegion,
                   ki: PrimExpr,
                   rk: PrimExpr | None = 0):
        # Fast path for fp64: no ldmatrix support, do direct per-lane loads
        if DataType(self.b_dtype).bits == 64:
            warp_col_tiles = self.warp_col_tiles
            warp_cols = self.warp_cols
            chunk = self.chunk
            micro_size_y = self.micro_size_y  # 8
            micro_size_k = self.micro_size_k  # 4
            local_size_b = self.local_size_b  # 1
            b_transposed = self.b_transposed
            thread_binding = self.get_thread_binding()

            # legalize shared buffer to region
            B_region = to_buffer_region(B_shared_buf)
            B_buf = B_region.buffer
            B_base0 = B_region.region[-2].min
            B_base1 = B_region.region[-1].min

            @T.macro
            def _warp_ld_b_fp64(
                B_local_buf,
                B_shared_buf,
                ki,
                thread_binding,
                rk=0,
            ):
                tx, warp_n, _ = self.extract_thread_binding(thread_binding)
                for j in T.serial(warp_cols):
                    wi = warp_n * warp_col_tiles + j * micro_size_y
                    wk = rk * chunk + ki * micro_size_k
                    mi = tx // micro_size_k
                    mk = tx % micro_size_k
                    if b_transposed:
                        B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
                    else:
                        B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

            return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)

        warp_col_tiles = self.warp_col_tiles
        warp_cols = self.warp_cols
        chunk = self.chunk
...
@@ -275,6 +395,13 @@ class TensorCoreIntrinEmitter:
        b_dtype = self.b_dtype
        b_transposed = self.b_transposed
        thread_binding = self.get_thread_binding()

        # legalize shared buffer to region
        B_region = to_buffer_region(B_shared_buf)
        B_buf = B_region.buffer
        B_base0 = B_region.region[-2].min
        B_base1 = B_region.region[-1].min
        B_stride_last = B_buf.shape[-1]

        replicate_b = (self.n_dim == 16)
        # ldmatrix cannot be used for int8 + trans case.
        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...
@@ -298,7 +425,7 @@ class TensorCoreIntrinEmitter:
            thread_binding,
            rk=0,
        ):
-           stride = B_shared_buf.shape[-1]
+           stride = B_stride_last
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
            trans = not b_transposed
...
@@ -310,8 +437,9 @@ class TensorCoreIntrinEmitter:
                )
                if ldmatrix_available:
-                   B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
+                   B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi]
                    T.ptx_ldmatrix(
                        b_dtype,
...
@@ -329,7 +457,12 @@ class TensorCoreIntrinEmitter:
                    # must be transposed.
                    for j in T.serial(local_size_b):
                        mi, mk = mma_load_layout(tx, j)
-                       B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi]
+                       if b_transposed:
+                           B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
+                       else:
+                           B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
...
@@ -617,8 +750,10 @@ class TensorCoreIntrinEmitter:
        from tilelang.utils import is_fragment

        shape = local_buf.shape
+       assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
        inverse_mma_store_layout = self.get_store_index_map(inverse=True)
-       assert is_fragment(local_buf), "local_buf must be a fragment"
        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
        local_size_out = self.local_size_out
        block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
...
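The pattern repeated throughout the diff above is legalizing a shared operand with to_buffer_region(...) so that a plain Buffer and a BufferRegion are handled uniformly: the region's per-axis minima become base offsets that are added to every index. Below is a minimal plain-Python sketch of that resolution step; the Range/Region2D classes are illustrative stand-ins for the TVM objects, not the tilelang API.

    # Illustrative-only: reduces a 2-D region to (buffer, row_base, col_base),
    # the same role A_region.region[-2].min / A_region.region[-1].min play above.
    from dataclasses import dataclass

    @dataclass
    class Range:
        min: int
        extent: int

    @dataclass
    class Region2D:
        buffer: list      # stand-in for a tir.Buffer's backing storage
        region: list      # [Range(row_min, rows), Range(col_min, cols)]

    def legalize(region: Region2D):
        base0 = region.region[-2].min
        base1 = region.region[-1].min
        return region.buffer, base0, base1

    data = [[r * 8 + c for c in range(8)] for r in range(8)]
    reg = Region2D(buffer=data, region=[Range(2, 4), Range(4, 4)])
    buf, base0, base1 = legalize(reg)
    # element (i, j) of the sub-tile maps to buf[base0 + i][base1 + j]
    assert buf[base0 + 1][base1 + 3] == data[3][7]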
tilelang/intrinsics/mma_sm70_layout.py
0 → 100644
View file @ bbbf4207

from __future__ import annotations


def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
    tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
    local_id = col
    return tid, local_id


def shared_4x16_to_mma_b_32x4_layout(row, col, rep):
    thread_id = row + 8 * col // 4 + 4 * rep
    local_id = col % 4
    return thread_id, local_id


def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
    thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8)
    local_id = col
    return thread_id, local_id


def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
    row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
    col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
    return row, col


def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
    row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8
    col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8
    return row, col


def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
    row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
    col = local_id
    return row, col


def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id):
    row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2)
    col = local_id
    return row, col


def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id):
    row = thread_id % 4
    col = local_id + (4 * (thread_id // 8))
    return row, col
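The helpers above map a (row, col, rep) coordinate of a 16x4 (A) or 4x16 (B) shared-memory micro-tile onto a (thread_id, local_id) pair inside a 32-lane warp. A small self-contained check, duplicating the A-operand formula verbatim, confirms that each replication index touches 16 distinct lanes with 4 distinct slots per lane:

    # Duplicated from shared_16x4_to_mma_a_32x4_layout above, for a quick sanity check.
    def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
        tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
        local_id = col
        return tid, local_id

    for rep in (0, 1):
        pairs = {shared_16x4_to_mma_a_32x4_layout(r, c, rep) for r in range(16) for c in range(4)}
        lanes = {tid for tid, _ in pairs}
        assert len(pairs) == 64 and len(lanes) == 16  # half the warp per replication index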
tilelang/intrinsics/mma_sm70_macro_generator.py
0 → 100644
View file @ bbbf4207

from __future__ import annotations

import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_sm70_layout import (
    shared_16x4_to_mma_a_32x4_layout,
    shared_4x16_to_mma_b_32x4_layout,
    shared_16x4_to_mma_b_32x4_layout_trans,
    mma_32x8_to_shared_16x16_layout_fp32,
    mma_32x8_to_shared_16x16_layout_fp16,
    mma_load_a_32x4_to_shared_16x4_layout,
    mma_load_b_32x4_to_shared_16x4_layout_trans,
    mma_load_b_32x4_to_shared_4x16_layout,
)

lift = convert


class TensorCoreIntrinEmitter:
    """
    To eliminate Python syntax within TIR Macro.
    """

    M_DIM = 16
    # use lowercase as n_dim can be dynamic
    # the smallest instructions can be m16n8k16, so the n_dim can also be 8
    n_dim = 16
    WARP_SIZE = 32
    HALF_WARP_SIZE = WARP_SIZE // 2
    dtype_abbrv = {
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
        "int8": "int8",
        "int32": "int32",
        "float8_e4m3": "e4m3",
        "float8_e5m2": "e5m2",
    }

    # Represent the thread binding in the form of (tx, warp_n, warp_m)
    is_m_first = False

    def __init__(
        self,
        a_dtype: str = "float16",
        b_dtype: str = "float16",
        accum_dtype: str = "float16",
        a_transposed: bool = False,
        b_transposed: bool = False,
        block_row_warps: int = 2,
        block_col_warps: int = 2,
        warp_row_tiles: int = 8,
        warp_col_tiles: int = 8,
        chunk: int = 16,
        reduce_k: int = 1,
        num_elems_per_byte: int = 1,
        is_m_first: bool | None = False,
        thread_var: Var | None = None,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.accum_dtype = accum_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        # Hint Information
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk
        self._initialize_k_dim(a_dtype)
        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
        self._initialize_micro_size(self.M_DIM, self.k_dim)
        self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim)
        self._initialize_mma_prefix(self.k_dim)
        self._initialize_is_m_first(is_m_first)

        self.reduce_k = reduce_k
        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
        self.num_elems_per_byte = num_elems_per_byte
        self.thread_var = thread_var

        if self.warp_rows == 0 or self.warp_cols == 0:
            raise ValueError(
                f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
            )

    def _initialize_k_dim(self, a_dtype="float16"):
        self.k_dim = 4

    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
        self.local_size_a = (m_dim * k_dim) // self.HALF_WARP_SIZE
        self.local_size_b = (n_dim * k_dim) // self.HALF_WARP_SIZE
        self.local_size_out = (m_dim * n_dim) // self.WARP_SIZE

    def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
        self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
        self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
        self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)

    def _get_dtype_abbrv(self, dtype: str) -> str:
        try:
            return self.dtype_abbrv[dtype]
        except KeyError as err:
            raise ValueError(f"Unsupported dtype: {dtype}") from err

    def _initialize_mma_prefix(self, k_dim: int = 16):
        if k_dim == 4:
            # typically used for float16
            self.mma_prefix = "m16n16k4"
        else:
            raise ValueError(f"Unsupported k_dim: {k_dim}")

    def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
        warp_row_tiles = self.warp_row_tiles
        warp_col_tiles = self.warp_col_tiles

        assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
        assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
        assert warp_col_tiles >= 16, f"warp_col_tiles must be greater than 16, got {warp_col_tiles}"
        assert warp_col_tiles % 16 == 0, f"warp_col_tiles must be divisible by 16, got {warp_col_tiles}"

        self.warp_rows = warp_row_tiles // m_dim
        self.n_dim = 16
        self.micro_size_y = 16
        self.warp_cols = warp_col_tiles // 16
        self.micro_size_x = m_dim
        self.micro_size_k = k_dim

    def _initialize_is_m_first(self, is_m_first: bool | None = False):
        if is_m_first is not None:
            self.is_m_first = is_m_first

    def get_thread_binding(self):
        if self.thread_var is None:
            current_frame = T.KernelLaunchFrame.Current()
            assert current_frame is not None, "Must be called in a T.Kernel Frame"
            return current_frame.get_thread_binding()
        else:
            return self.thread_var

    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
        index_map = IndexMap.from_func(
            mma_32x8_to_shared_16x16_layout_fp32
            if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
            index_dtype="int32")
        if not inverse:
            return index_map
        inverse_index_map = index_map.inverse([warp_size, local_size_c])
        return inverse_index_map

    def extract_thread_binding(
            self,
            thread_id: PrimExpr,
            is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
        """
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
        Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
        """
        WARP_SIZE = self.WARP_SIZE
        block_row_warps = self.block_row_warps
        block_col_warps = self.block_col_warps

        # if is_m_first is None, then use the default value
        if is_m_first is None:
            is_m_first = self.is_m_first

        if is_m_first:
            lane_id, warp_n, warp_m = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_col_warps,
                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
            )
            return lane_id, warp_n, warp_m
        else:
            lane_id, warp_m, warp_n = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_row_warps,
                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
            )
            return lane_id, warp_n, warp_m

    def ldmatrix_a(self,
                   A_local_buf: Buffer,
                   A_shared_buf: Buffer | BufferRegion,
                   ki: PrimExpr,
                   rk: PrimExpr | None = 0):
        warp_row_tiles = self.warp_row_tiles
        warp_rows = self.warp_rows
        chunk = self.chunk
        micro_size_x = self.micro_size_x
        micro_size_k = self.micro_size_k
        local_size_a = self.local_size_a
        a_transposed = self.a_transposed
        thread_binding = self.get_thread_binding()

        assert not a_transposed, "A must be not transposed"
        mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout

        # legalize shared buffer to region
        A_region = to_buffer_region(A_shared_buf)
        A_buf = A_region.buffer
        A_base0 = A_region.region[-2].min
        A_base1 = A_region.region[-1].min

        @T.macro
        def _warp_ldmatrix_a(
            A_local_buf,
            A_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            for i in T.serial(warp_rows):
                # Assign A_shared_buf_elem
                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
                for j in T.vectorized(local_size_a):
                    mi, mk = mma_load_layout(tx, j)
                    A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]

        return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

    def ldmatrix_b(self,
                   B_local_buf: Buffer,
                   B_shared_buf: Buffer | BufferRegion,
                   ki: PrimExpr,
                   rk: PrimExpr | None = 0):
        warp_col_tiles = self.warp_col_tiles
        warp_cols = self.warp_cols
        chunk = self.chunk
        micro_size_y = self.micro_size_y
        micro_size_k = self.micro_size_k
        local_size_b = self.local_size_b
        b_transposed = self.b_transposed
        thread_binding = self.get_thread_binding()

        mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout

        # legalize shared buffer to region
        B_region = to_buffer_region(B_shared_buf)
        B_buf = B_region.buffer
        B_base0 = B_region.region[-2].min
        B_base1 = B_region.region[-1].min

        @T.macro
        def _warp_ldmatrix_b(
            B_local_buf,
            B_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)

            for i in T.serial(warp_cols):
                # Assign B_shared_elem
                wi, wk = (
                    warp_n * warp_col_tiles + i * micro_size_y,
                    rk * chunk + ki * micro_size_k,
                )
                # load 16x32 data from shared buffer to local buffer
                # must be transposed.
                for j in T.vectorized(local_size_b):
                    if b_transposed:
                        mi, mk = mma_load_layout(tx, j)
                        B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
                    else:
                        mk, mi = mma_load_layout(tx, j)
                        B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]

        return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)

    def mma(self,
            A_local_buf: Buffer,
            B_local_buf: Buffer,
            C_local_buf: Buffer,
            k_inner: PrimExpr | None = 0):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
        local_size_b = self.local_size_b
        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
        b_dtype_abbrv = self.b_dtype_abbrv
        accum_dtype_abbrv = self.accum_dtype_abbrv
        mma_prefix = self.mma_prefix

        a_is_fragment = is_fragment(A_local_buf)
        b_is_fragment = is_fragment(B_local_buf)
        a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
        b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
        a_major = "col" if self.a_transposed else "row"
        b_major = "col" if self.b_transposed else "row"

        @T.macro
        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
            for i, j in T.grid(warp_rows, warp_cols):
                T.ptx_mma_sm70(
                    mma_prefix,
                    a_major,
                    b_major,
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    a_local_stride + i * local_size_a,
                    B_local_buf.data,
                    b_local_stride + j * local_size_b,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out,
                )

        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
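For SM70 the K dimension of a single mma step is 4 and the A/B operands are spread over half a warp, so the per-lane register counts and the number of T.ptx_mma_sm70 issues follow directly from _initialize_local_size and _initialize_micro_size. A plain-Python rendition of that bookkeeping, under an assumed 32x32 warp tile (the tile sizes below are illustrative, not values fixed by the file):

    # Mirrors _initialize_local_size/_initialize_micro_size above with concrete numbers.
    M_DIM, n_dim, k_dim = 16, 16, 4
    WARP_SIZE, HALF_WARP_SIZE = 32, 16
    warp_row_tiles, warp_col_tiles = 32, 32   # assumed; must be multiples of 16

    local_size_a = (M_DIM * k_dim) // HALF_WARP_SIZE   # 4 A elements per lane
    local_size_b = (n_dim * k_dim) // HALF_WARP_SIZE   # 4 B elements per lane
    local_size_out = (M_DIM * n_dim) // WARP_SIZE      # 8 accumulators per lane
    warp_rows = warp_row_tiles // M_DIM                # 2
    warp_cols = warp_col_tiles // 16                   # 2
    mma_issues_per_k_step = warp_rows * warp_cols      # 4 ptx_mma_sm70 calls
    assert (local_size_a, local_size_b, local_size_out, mma_issues_per_k_step) == (4, 4, 8, 4)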
    def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
        map fragment indices to threads and local indices.

        Parameters
        ----------
        local_buf : tir.Buffer
            The local buffer representing a fragment of a matrix.

        Returns
        -------
        T.Fragment
            A fragment object that describes how threads and indices
            in `local_buf` are laid out.

        Raises
        ------
        AssertionError
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        assert matrix in ["A", "B"], "matrix should be either A or B"
        matrix_is_a: bool = matrix == "A"
        matrix_is_b: bool = matrix == "B"
        dtype = self.a_dtype if matrix_is_a else self.b_dtype
        dtype_bits = DataType(dtype).bits
        transposed = self.a_transposed if matrix_is_a else self.b_transposed

        # s represents spatial axis
        # r represents reduction axis
        # sr represents the two dims are spatial + reduction
        # rs represents the two dims are reduction + spatial
        # sr also can represent a non-transposed basic layout
        # then rs also can represent a transposed basic layout
        transform_func_sr_a: Callable = None
        transform_func_sr_b: Callable = None
        transform_func_rs_b: Callable = None
        if dtype_bits == 16:
            transform_func_sr_a = shared_16x4_to_mma_a_32x4_layout
            transform_func_sr_b = shared_16x4_to_mma_b_32x4_layout_trans
            transform_func_rs_b = shared_4x16_to_mma_b_32x4_layout
        else:
            raise ValueError(f"Unsupported dtype {dtype}")

        is_sr_conditions = [False]
        is_sr_conditions.append(matrix_is_a and not transposed)
        is_sr_conditions.append(matrix_is_b and transposed)
        is_sr_axis_order = any(is_sr_conditions)

        # the layout of mma.sync is row.col.
        # so the b matrix expected a transposed basic layout
        transform_func: Callable = None
        if matrix_is_a:
            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        elif matrix_is_b:
            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
        else:
            raise ValueError(f"Unsupported matrix {matrix}")

        assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"

        if matrix_is_a:
            micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
        else:
            micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y

        block_row_warps, block_col_warps = (
            self.block_row_warps,
            self.block_col_warps,
        )

        inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

        def forward(i: int, j: int, rep: int) -> int:
            """
            Given the row index `i` and column index `j` in the fragment,
            """
            lane_id, local_id = inverse_mma_load_layout.map_indices([i, j, rep])
            return lane_id, local_id

        base_fragment = T.Fragment(
            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
            forward_fn=forward,
            replicate=2)

        warp_rows, warp_cols = self.warp_rows, self.warp_cols
        chunk = self.chunk
        warp_s = warp_rows if matrix_is_a else warp_cols
        warp_r = chunk // micro_size_r
        block_s = block_row_warps if matrix_is_a else block_col_warps
        replicate = block_col_warps if matrix_is_a else block_row_warps

        if is_sr_axis_order:
            warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")
        else:
            warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
            if matrix_is_a:
                block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
            elif matrix_is_b:
                block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
            else:
                raise ValueError(f"Unsupported matrix type {matrix}")

        return block_fragment

    def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
        """
        Create a layout function for storing MMA results into a fragment buffer.
        This layout is used in conjunction with `inverse_mma_store_layout` to
        map fragment indices to threads and local indices.

        Parameters
        ----------
        local_buf : tir.Buffer
            The local buffer representing a fragment of a matrix.

        Returns
        -------
        T.Fragment
            A fragment object that describes how threads and indices
            in `local_buf` are laid out.

        Raises
        ------
        AssertionError
            If `local_buf` is not detected to be a fragment buffer.
        """
        from tilelang.utils import is_fragment

        shape = local_buf.shape
        inverse_mma_store_layout = self.get_store_index_map(inverse=True)
        assert is_fragment(local_buf), "local_buf must be a fragment"
        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
        local_size_out = self.local_size_out
        block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
        warp_rows, warp_cols = self.warp_rows, self.warp_cols
        warp_size = self.WARP_SIZE
        is_m_first = self.is_m_first

        def forward_thread(i: int, j: int) -> int:
            """
            Given the row index `i` and column index `j` in the fragment,
            map them to a thread index according to `inverse_mma_store_layout`.
            """
            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
            # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
            block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
            # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
            mma_i, mma_j = i % micro_size_x, j % micro_size_y
            lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
            if is_m_first:
                thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
            else:
                thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
            return thread_id

        def forward_index(i: int, j: int) -> int:
            """
            Given the row index `i` and column index `j` in the fragment,
            map them to a local index in a single thread according
            to `inverse_mma_store_layout`.
            """
            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
            # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
            warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
            # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
            mma_i, mma_j = i % micro_size_x, j % micro_size_y
            _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
            return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id

        return T.Fragment(
            shape,
            forward_thread_fn=forward_thread,
            forward_index_fn=forward_index,
        )
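The thread decomposition used by every ldmatrix/mma macro in this file is the one computed in extract_thread_binding. A plain-integer rendition of the default (is_m_first=False) case makes the ordering explicit; the block sizes below are assumptions for illustration only:

    # Thread ids laid out as [warp_size, block_row_warps, block_col_warps] (is_m_first=False).
    WARP_SIZE = 32
    block_row_warps, block_col_warps = 2, 2   # assumed

    def extract(thread_id):
        lane_id = thread_id % WARP_SIZE
        warp_m = (thread_id // WARP_SIZE) % block_row_warps
        warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
        return lane_id, warp_n, warp_m

    assert extract(0) == (0, 0, 0)    # lane 0 of warp (m=0, n=0)
    assert extract(33) == (1, 0, 1)   # consecutive warps advance warp_m first
    assert extract(64) == (0, 1, 0)   # then warp_n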
tilelang/intrinsics/tcgen05_macro_generator.py
0 → 100644
View file @ bbbf4207

from __future__ import annotations

from enum import IntEnum
import tilelang.language as T
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, BufferLoad, BufferRegion
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tilelang.utils import is_tensor_memory
from tilelang.layout import (
    Layout,
    make_full_bank_swizzled_layout,
    make_half_bank_swizzled_layout,
    make_quarter_bank_swizzled_layout,
    make_linear_layout,
)
from tvm.runtime import convert

lift = convert


class SwizzleMode(IntEnum):
    # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
    NONE = 0
    SWIZZLE_128B = 2
    SWIZZLE_64B = 4
    SWIZZLE_32B = 6

    def is_none(self) -> bool:
        return self == SwizzleMode.NONE

    def is_swizzle_32b(self) -> bool:
        return self == SwizzleMode.SWIZZLE_32B

    def is_swizzle_64b(self) -> bool:
        return self == SwizzleMode.SWIZZLE_64B

    def is_swizzle_128b(self) -> bool:
        return self == SwizzleMode.SWIZZLE_128B

    def swizzle_byte_size(self) -> int:
        if self.is_swizzle_32b():
            return 32
        elif self.is_swizzle_64b():
            return 64
        elif self.is_swizzle_128b():
            return 128
        else:
            return 1

    def swizzle_atom_size(self) -> int:
        if self.is_swizzle_32b():
            return 32 // 16
        elif self.is_swizzle_64b():
            return 64 // 16
        elif self.is_swizzle_128b():
            return 128 // 16
        else:
            return 1


# derive from MMAIntrinEmitter as some layouts are the same
class TensorCoreIntrinEmitter(MMAIntrinEmitter):
    """
    To eliminate Python syntax within TIR Macro.
    """

    # should be rewritten to support dynamic k_dim
    tcgen05_prefix: str
    a_shared_layout: Layout = None
    b_shared_layout: Layout = None

    def __init__(
        self,
        a_dtype: str = "float16",
        b_dtype: str = "float16",
        accum_dtype: str = "float16",
        a_transposed: bool = False,
        b_transposed: bool = False,
        block_row_warps: int = 2,
        block_col_warps: int = 2,
        warp_row_tiles: int = 8,
        warp_col_tiles: int = 8,
        chunk: int = 16,
        reduce_k: int = 1,
        num_elems_per_byte: int = 1,
        is_m_first: bool = False,
        thread_var: Var | None = None,
    ):
        super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed,
                         block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
                         chunk, reduce_k, num_elems_per_byte, is_m_first, thread_var)

    def _assign_a_shared_layout(self, layout: Layout):
        self.a_shared_layout = layout
        return self

    def _assign_b_shared_layout(self, layout: Layout):
        self.b_shared_layout = layout
        return self

    def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
        warp_row_tiles = self.warp_row_tiles
        warp_col_tiles = self.warp_col_tiles

        # For tcgen05, warp_row_tiles is 8 as we can use .ws to support m32
        assert warp_row_tiles >= 8, f"warp_row_tiles must be greater than 8, got {warp_row_tiles}"
        assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}"
        assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
        assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"

        # four warps per block
        self.warp_rows = warp_row_tiles // 8
        if warp_col_tiles % 16 == 0:
            self.n_dim = 16
            self.micro_size_y = 16
            self.warp_cols = warp_col_tiles // 16
        else:
            # must be divisible by 8
            self.n_dim = 8
            self.micro_size_y = 8
            self.warp_cols = warp_col_tiles // 8

        self.micro_size_x = m_dim
        self.micro_size_k = k_dim

    def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
        # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper
        if layout is None or layout.is_equal(make_linear_layout(buffer)):
            return SwizzleMode.NONE
        elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)):
            return SwizzleMode.SWIZZLE_32B
        elif layout.is_equal(make_half_bank_swizzled_layout(buffer)):
            return SwizzleMode.SWIZZLE_64B
        elif layout.is_equal(make_full_bank_swizzled_layout(buffer)):
            return SwizzleMode.SWIZZLE_128B
        else:
            raise ValueError(f"Unsupported swizzle mode: {layout}")

    def tcgen05mma(self,
                   A_buf: Buffer,
                   B_buf: Buffer,
                   C_local_buf: Buffer,
                   mbar,
                   clear_accum: PrimExpr = False):
        if is_tensor_memory(A_buf):
            return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)

        accum_dtype = self.accum_dtype
        m_dim = self.block_row_warps * self.warp_row_tiles
        micro_size_k = self.micro_size_k
        k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
        scale_in_a = 1
        scale_in_b = 1

        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"

        a_is_k_major = not self.a_transposed
        b_is_k_major = self.b_transposed

        a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
        elems_in_bits = DataType(self.a_dtype).bits
        elems_in_bytes = elems_in_bits // 8
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        accum_dtype_in_bits = DataType(accum_dtype).bits

        meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
        if len(meta) != 3:
            raise ValueError(
                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
        atom_m, atom_n, atom_k = (int(x) for x in meta)
        enable_ws = atom_m != 128

        # by default, we utilize non-swizzle layout offset
        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)

        if not a_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
            if a_is_k_major:
                a_leading_byte_offset = 16
                a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
            else:
                # MN Major
                # LBO represents the distance between two atoms along the M dimension
                # SBO represents the distance between two atoms along the K dimension
                a_m_axis_atoms = m_dim // a_swizzle_atom_elems
                if a_m_axis_atoms <= 1:
                    a_leading_byte_offset = 0
                else:
                    a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size()
                if a_m_axis_atoms <= 1:
                    a_stride_byte_offset = 8 * elems_in_bytes * m_dim
                else:
                    a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems

        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))

        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
            if b_is_k_major:
                b_leading_byte_offset = 16
                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
            else:
                # MN Major, K * N
                # LBO represents the distance between two atoms along the N dimension
                # SBO represents the distance between two atoms along the K dimension
                b_n_axis_atoms = n_dim // b_swizzle_atom_elems
                if b_n_axis_atoms <= 1:
                    b_leading_byte_offset = 0
                else:
                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
                if b_n_axis_atoms <= 1:
                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
                else:
                    b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

        # for example, if [n, k] where k is 128, we should split it into 2 atoms
        # where max specially handles the case when n_dim is 8.
        ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)

        instr_desc = self.get_tcgen5_instr_desc(
            atom_m,
            atom_n,
            atom_k,
            a_is_k_major,
            b_is_k_major,
            scale_in_a,
            scale_in_b,
        )

        # Allocate an instruction descriptor wrapper and initialize it
        a_dtype_abbrv = self.a_dtype_abbrv
        mask_zero = T.Cast("int32", 0)
        mask0 = mask1 = mask2 = mask3 = mask_zero
        num_inst_m = 4 * self.warp_row_tiles // atom_m
        num_inst_n = self.warp_col_tiles // atom_n

        # Helper to allow BufferRegion/BufferLoad as inputs
        def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
            if isinstance(buffer_or_load_or_region, Buffer):
                return buffer_or_load_or_region.access_ptr(access_type)
            elif isinstance(buffer_or_load_or_region, BufferLoad):
                buffer_load = buffer_or_load_or_region
                offset, stride = 0, 1
                buffer = buffer_load.buffer
                for i, shape in enumerate(reversed(buffer.shape)):
                    indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
                    if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
                        offset += indice * stride
                    elif isinstance(indice, tvm.tir.Ramp):
                        offset += indice.base * stride
                    else:
                        raise ValueError(f"Unsupported index type: {type(indice)}")
                    stride *= shape
                return buffer.access_ptr(access_type, offset=offset)
            elif isinstance(buffer_or_load_or_region, BufferRegion):
                buffer_region = buffer_or_load_or_region
                buffer = buffer_region.buffer
                offset, stride = 0, 1
                for i, shape in enumerate(reversed(buffer.shape)):
                    offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
                    stride *= shape
                return buffer.access_ptr(access_type, offset=offset)
            else:
                raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")

        @T.macro
        def _warp_mma(A_buf, B_buf, C_local_buf, mbar):
            # Allocate SMEM descriptors for A and B
            desc_a = T.alloc_tcgen05_smem_desc()
            desc_b = T.alloc_tcgen05_smem_desc()
            A_ptr = access_ptr_from(A_buf, "r")
            B_ptr = access_ptr_from(B_buf, "r")
            T.initialize_tcgen05_descriptor(
                desc_a,
                A_ptr,
                int(a_leading_byte_offset >> 4),
                int(a_stride_byte_offset >> 4),
                0,
                False,
                int(a_swizzle_mode),
            )
            T.initialize_tcgen05_descriptor(
                desc_b,
                B_ptr,
                int(b_leading_byte_offset >> 4),
                int(b_stride_byte_offset >> 4),
                0,
                False,
                int(b_swizzle_mode),
            )
            tmem_col_step = atom_n // (128 // atom_m)
            for j in T.unroll(num_inst_n):
                for i in T.unroll(num_inst_m):
                    for ki in T.unroll(0, (k_dim // micro_size_k)):
                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                        A_elem_offset = ((ki % ak_atom_size) * micro_size_k + i * atom_m * a_swizzle_atom_elems +
                                         (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems) if a_is_k_major else (
                                             i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k)
                        B_elem_offset = ((ki // bk_atom_size) * n_dim * b_swizzle_atom_elems +
                                         (ki % bk_atom_size) * micro_size_k + j * atom_n * b_swizzle_atom_elems) if b_is_k_major else (
                                             ki * b_swizzle_atom_elems * micro_size_k + j * atom_n *
                                             (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
                        A_byte_offset = A_elem_offset * elems_in_bytes
                        B_byte_offset = B_elem_offset * elems_in_bytes
                        C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
                        T.ptx_tcgen05_mma_ss(
                            a_dtype_abbrv,
                            desc_a.data,
                            A_byte_offset,
                            desc_b.data,
                            B_byte_offset,
                            C_local_buf.data,
                            C_offset,
                            instr_desc,
                            scale_out,
                            mask0,
                            mask1,
                            mask2,
                            mask3,
                            enable_ws,
                        )
            T.tcgen05_mma_arrive(mbar)

        return _warp_mma(A_buf, B_buf, C_local_buf, mbar)

    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
        raise NotImplementedError

    def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout:
        """
        Create the TCGEN5 tensor-memory layout used to store MMA accumulators.

        Parameters
        ----------
        tmem_buf : tir.Buffer
            The local buffer representing tensormemory of a mma's output

        Returns
        -------
        Layout
            Layout object describing how logical (i, j) coordinates map to the
            swizzled tensor-memory offsets required by TCGEN5MMA.

        Raises
        ------
        AssertionError
            If `tmem_buf` is not detected to be a tensor-memory buffer.
        """
        assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
        if len(tmem_buf.shape) != 2:
            raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")

        m = int(tmem_buf.shape[0])
        n = int(tmem_buf.shape[1])
        k = int(self.chunk)
        meta = self.get_tcgen5_mma_meta(m, n, k)
        if len(meta) != 3:
            raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
                             f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
        atom_m, atom_n, _ = (int(x) for x in meta)
        if m % atom_m != 0 or n % atom_n != 0:
            raise ValueError(
                f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})")

        def forward(i: PrimExpr, j: PrimExpr):
            atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
            ai = i % atom_m
            aj = j % atom_n
            if atom_m == 128:
                # Layout D
                return [
                    ai,
                    aj + atom_idx * atom_n,
                ]
            if atom_m == 64:
                # Layout E (.ws variant)
                half_atom_n = atom_n // 2
                return [
                    (ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64,
                    (aj % half_atom_n) + atom_idx * half_atom_n,
                ]
            if atom_m == 32:
                # Layout G
                quarter_atom_n = atom_n // 4
                return [
                    ai % 32 + (aj // quarter_atom_n) * 32,
                    (aj % quarter_atom_n) + atom_idx * quarter_atom_n,
                ]
            raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}")

        return Layout([m, n], forward)

    def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
        return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))

    def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
                              b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
        desc = _ffi_api.get_tcgen5_instr_desc(
            atom_m,
            atom_n,
            atom_k,
            DataType(self.a_dtype),
            DataType(self.accum_dtype),
            a_is_k_major,
            b_is_k_major,
            scale_in_a,
            scale_in_b,
        )
        return lift(desc)
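The descriptor initialization above shifts the leading/stride byte offsets right by 4 because the TCGEN05 shared-memory descriptors encode those fields in 16-byte units. A tiny worked example of the K-major, 128-byte-swizzled branch for fp16 operands (the values mirror the a_is_k_major path above; the concrete numbers are illustrative):

    # LBO/SBO arithmetic for a K-major, SWIZZLE_128B fp16 A operand.
    elems_in_bytes = 2                               # fp16
    swizzle_byte_size = 128                          # SWIZZLE_128B
    a_leading_byte_offset = 16                       # fixed constant in the swizzled K-major case
    a_stride_byte_offset = 8 * swizzle_byte_size     # 1024 bytes between 8-row groups
    # Descriptors store offsets in 16-byte granularity, hence the >> 4 shifts.
    assert (a_leading_byte_offset >> 4, a_stride_byte_offset >> 4) == (1, 64)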
tilelang/intrinsics/utils.py
View file @ bbbf4207
...
@@ -8,6 +8,7 @@ from .mma_layout import (
    ldmatrix_32x16_to_shared_16x32_layout_a,
    ldmatrix_32x16_to_shared_16x32_layout_b,
    mma_store_32x8_to_shared_16x16_layout,
+   mma_store_32x2_to_shared_8x8_layout_fp64,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
...
@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id):
    return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)


+def mma_store_index_map_fp64(thread_id, local_id):
+    return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id)
+
+
def mfma_store_index_map(thread_id, local_id):
    return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
...
tilelang/intrinsics/wgmma_macro_generator.py
View file @ bbbf4207
...
@@ -4,8 +4,9 @@ from enum import IntEnum
from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
-from tvm.tir import PrimExpr, Buffer, Var, IndexMap
-from tilelang.utils import is_fragment
+from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferRegion
+from tilelang.utils import is_fragment, retrive_ptr_from_buffer_region, is_full_region
+from math import gcd
from tilelang.layout import (
    Layout,
    make_full_bank_swizzled_layout,
...
@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
    # should be rewritten to support dynamic k_dim
    wgmma_prefix: str
+   # wgmma instruction M dimension
+   wgmma_inst_m: int
+   # wgmma instruction N dimension
+   wgmma_inst_n: int
    a_shared_layout: Layout = None
    b_shared_layout: Layout = None
...
@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        return self

    def _initialize_wgmma_prefix(self, n_dim: int = 16):
-       inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
+       inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
        assert inst_n % 8 == 0, (
            f"inst_n must be a multiple of 8, got {inst_n} "
            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
        # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
        assert 8 <= inst_n <= 256, (
            f"inst_n must be within [8, 256], got {inst_n} "
            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
        # 256 bits per instruction
        inst_k = 256 // DataType(self.a_dtype).bits
+       self.wgmma_inst_m = inst_m
+       self.wgmma_inst_n = inst_n
        self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"

    def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
...
@@ -146,13 +161,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
            raise ValueError(f"Unsupported swizzle mode: {layout}")

    def wgmma(self,
-             A_buf: Buffer,
-             B_buf: Buffer,
-             C_local_buf: Buffer,
-             clear_accum: PrimExpr = False):
-       if is_fragment(A_buf):
-           return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum)
+             A_region: BufferRegion,
+             B_region: BufferRegion,
+             C_region: BufferRegion,
+             clear_accum: PrimExpr = False,
+             wg_wait: int = 0):
+       if is_fragment(A_region):
+           return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)

        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
...
@@ -164,7 +180,6 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        micro_size_k = self.micro_size_k
        k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
        wgmma_prefix = self.wgmma_prefix
-       scale_out = not clear_accum
        scale_in_a = 1
        scale_in_b = 1
...
@@ -173,8 +188,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        a_is_k_major = not self.a_transposed
        b_is_k_major = self.b_transposed
-       a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
-       b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
+       a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
+       b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
        elems_in_bits = DataType(self.a_dtype).bits
        elems_in_bytes = elems_in_bits // 8
...
@@ -182,6 +197,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+       accum_bits = DataType(accum_dtype).bits
+       accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

        # by default, we utilize non-swizzle layout offset
        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
...
@@ -240,41 +257,69 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        # where max specially handles the case when n_dim is 8.
        ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+       wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
+       num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
+       num_inst_n = self.warp_col_tiles // wgmma_inst_n
+       thread_binding = self.get_thread_binding()
+       A_ptr = retrive_ptr_from_buffer_region(A_region)
+       B_ptr = retrive_ptr_from_buffer_region(B_region)
+       assert is_full_region(C_region), "Fragment output C must be a full region"
+       C_buf = C_region.buffer

        @T.macro
-       def _warp_mma(A_buf, B_buf, C_local_buf):
-           # TODO(lei): inject warpgroup_fence_operand for C_local_buf
-           desc_a = T.alloc_descriptor()
-           desc_b = T.alloc_descriptor()
-           T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
-                                   int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-           T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
-                                   int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+       def _warp_mma(A_ptr, B_ptr, C_buf):
+           tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+           desc_a = T.alloc_wgmma_desc()
+           desc_b = T.alloc_wgmma_desc()
+           T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
+                                         int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
+           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
+                                         int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+           T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
            T.warpgroup_arrive()
-           for ki in T.serial(0, (k_dim // micro_size_k)):
-               for i in T.serial(m_dim // 64):
-                   A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
-                       ki // ak_atom_size) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
-                   B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
-                       ki % bk_atom_size) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
-                   C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
+           for j in T.unroll(num_inst_n):
+               for i in T.unroll(num_inst_m):
+                   for ki in T.unroll(k_dim // micro_size_k):
+                       scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
+                       warp_i = (warp_m // 4) * num_inst_m + i
+                       warp_j = warp_n * num_inst_n + j
+                       A_offset = (ki % ak_atom_size) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
+                           ki // ak_atom_size) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                       B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
+                           ki % bk_atom_size) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (
+                               ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
+                               (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                       C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                       T.ptx_wgmma_ss(
                           accum_dtype,
                           wgmma_prefix,
                           a_is_k_major,
                           b_is_k_major,
                           a_dtype_abbrv,
                           b_dtype_abbrv,
                           accum_dtype_abbrv,
                           desc_a.data,
                           (A_offset * elems_in_bytes) >> 4,
                           desc_b.data,
                           (B_offset * elems_in_bytes) >> 4,
-                          C_local_buf.data,
+                          C_buf.data,
                           C_offset,
                           scale_out, scale_in_a, scale_in_b)
           T.warpgroup_commit_batch()
-          T.warpgroup_wait(0)
+          if wg_wait >= 0:
+              T.warpgroup_wait(wg_wait)
+          T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)

-      return _warp_mma(A_buf, B_buf, C_local_buf)
+      return _warp_mma(A_ptr, B_ptr, C_buf)

    def wgmma_rs(self,
-                A_buf: Buffer,
-                B_buf: Buffer,
-                C_local_buf: Buffer,
-                clear_accum: PrimExpr = False):
+                A_region: BufferRegion,
+                B_region: BufferRegion,
+                C_region: BufferRegion,
+                clear_accum: PrimExpr = False,
+                wg_wait: int = 0):
        local_size_a = self.local_size_a
        local_size_out = self.local_size_out
        a_dtype_abbrv = self.a_dtype_abbrv
...
@@ -286,60 +331,90 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        micro_size_k = self.micro_size_k
        k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
        wgmma_prefix = self.wgmma_prefix
-       scale_out = not clear_accum
        scale_in_a = 1
        scale_in_b = 1

        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
        elems_in_bytes = DataType(self.a_dtype).bits // 8
+       a_bits = DataType(self.a_dtype).bits
+       accum_bits = DataType(accum_dtype).bits
+       a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32
+       accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

        b_is_k_major = self.b_transposed
-       b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
+       b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+       b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
        if not b_swizzle_mode.is_none():
            # swizzle mode doesn't require LBO/SBO to be 1
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
            if b_is_k_major:
                b_leading_byte_offset = 16
                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
            else:
                # MN Major
                # LBO represents the distance between two atoms along the N dimension
                # SBO represents the distance between two atoms along the K dimension
-               b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+               b_n_axis_atoms = n_dim // b_swizzle_atom_elems
                if b_n_axis_atoms <= 1:
                    b_leading_byte_offset = 0
                else:
-                   b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                   b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
                if b_n_axis_atoms <= 1:
                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
                else:
-                   b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                   b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems

+       bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+       wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
+       num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
+       num_inst_n = self.warp_col_tiles // wgmma_inst_n
+       thread_binding = self.get_thread_binding()
+       assert is_full_region(A_region), "Fragment input A must be a full region"
+       assert is_full_region(C_region), "Fragment output C must be a full region"
+       A_buf = A_region.buffer
+       B_ptr = retrive_ptr_from_buffer_region(B_region)
+       C_buf = C_region.buffer

        @T.macro
-       def _warp_mma(A_buf, B_buf, C_local_buf):
-           desc_b = T.alloc_descriptor()
-           T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
-                                   int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
-           for ki in T.serial(0, (k_dim // micro_size_k)):
-               for i in T.serial(m_dim // 64):
-                   k_dim_offset = ki * micro_size_k
+       def _warp_mma(A_buf, B_ptr, C_buf):
+           tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+           desc_b = T.alloc_wgmma_desc()
+           T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
+                                         int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+           T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
+           T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+           T.warpgroup_arrive()
+           for j in T.unroll(0, num_inst_n):
+               for i in T.unroll(num_inst_m):
+                   for ki in T.unroll(0, (k_dim // micro_size_k)):
+                       warp_j = warp_n * num_inst_n + j
+                       scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                       A_offset = ki * warp_rows * local_size_a + i * local_size_a
-                  B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1]
-                  C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
+                      B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
+                          ki % bk_atom_size) * micro_size_k if b_is_k_major else (
+                              ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
+                              (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                      C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                       T.ptx_wgmma_rs(
                           accum_dtype,
                           wgmma_prefix,
                           self.a_transposed,
-                          not self.b_transposed,
+                          self.b_transposed,
                           a_dtype_abbrv,
                           b_dtype_abbrv,
                           accum_dtype_abbrv,
...
@@ -347,14 +422,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                           A_offset,
                           desc_b.data,
                           (B_offset * elems_in_bytes) >> 4,
-                          C_local_buf.data,
+                          C_buf.data,
                           C_offset,
                           scale_out,
                           scale_in_a,
                           scale_in_b,
                       )
-      return _warp_mma(A_buf, B_buf, C_local_buf)
+          T.warpgroup_commit_batch()
+          if wg_wait >= 0:
+              T.warpgroup_wait(wg_wait)
+          T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+          T.warpgroup_fence_operand(A_buf, num_regs=a_regs)

+      return _warp_mma(A_buf, B_ptr, C_buf)

    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
        """
...
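The rewritten wgmma loop nest issues one m64n{inst_n}k{inst_k} instruction per (i, j, ki) triple, with inst_n now derived via gcd(warp_col_tiles, 256). A quick plain-Python check of that instruction-count bookkeeping, under assumed tile sizes and taking micro_size_k equal to the fp16 inst_k of 16 (all concrete numbers here are assumptions, not values from the diff):

    from math import gcd

    warp_row_tiles, warp_col_tiles, chunk = 64, 128, 64   # assumed tile shape
    a_dtype_bits = 16                                     # fp16
    wgmma_inst_m = 64
    wgmma_inst_n = gcd(warp_col_tiles, 256)               # 128
    inst_k = 256 // a_dtype_bits                          # 16 elements of K per instruction
    num_inst_m = 4 * warp_row_tiles // wgmma_inst_m       # 4 warps form one warpgroup
    num_inst_n = warp_col_tiles // wgmma_inst_n
    num_k_steps = chunk // inst_k
    assert (num_inst_m, num_inst_n, num_k_steps) == (4, 1, 4)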