Unverified commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
...@@ -4,17 +4,24 @@ This module provides functionality for auto-tuning tilelang programs, including ...@@ -4,17 +4,24 @@ This module provides functionality for auto-tuning tilelang programs, including
and performance optimization through configuration search. and performance optimization through configuration search.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.jit import JITImpl
from tilelang.jit.kernel import JITKernel
from tvm.tir import PrimFunc, Var from tvm.tir import PrimFunc, Var
from tvm.target import Target from tvm.target import Target
import inspect import inspect
from functools import partial from functools import partial
from typing import (Callable, Literal, Any, overload) from typing import (Callable, Generic, Literal, Any, TypeVar)
from tqdm import tqdm # Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging import logging
import functools
import concurrent.futures import concurrent.futures
import torch import torch
import os import os
...@@ -30,7 +37,6 @@ from tilelang import env ...@@ -30,7 +37,6 @@ from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.jit.param import _P, _RProg
from tilelang import __version__ from tilelang import __version__
...@@ -524,12 +530,12 @@ class AutoTuner: ...@@ -524,12 +530,12 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel) # latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException: except TimeoutException:
logger.info( logger.warning(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
) )
continue continue
except Exception: except Exception:
logger.info( logger.warning(
f"An error occurred while testing config {config}, checkout autotuner.log for more details" f"An error occurred while testing config {config}, checkout autotuner.log for more details"
) )
logger.debug(f"Error: {traceback.format_exc()}") logger.debug(f"Error: {traceback.format_exc()}")
...@@ -585,9 +591,13 @@ class AutoTuner: ...@@ -585,9 +591,13 @@ class AutoTuner:
return self.run() return self.run()
class _AutoTunerImplementation: _P = ParamSpec('_P')
# Overload __init__ to help type checkers understand the effect of return_program _T = TypeVar('_T')
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@dataclass
class AutoTuneImpl(Generic[_P, _T]):
jit_impl: JITImpl
warmup: int = 25 warmup: int = 25
rep: int = 100 rep: int = 100
...@@ -603,125 +613,51 @@ class _AutoTunerImplementation: ...@@ -603,125 +613,51 @@ class _AutoTunerImplementation:
manual_check_prog: Callable = None manual_check_prog: Callable = None
cache_input_tensors: bool = False cache_input_tensors: bool = False
def __init__(self, def __post_init__(self):
configs: dict | Callable, self._tuner_cache = {}
warmup: int = 25,
rep: int = 100, def get_tunner(self):
timeout: int = 100, autotuner = AutoTuner(
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, self.jit_impl.func, configs=self.configs).set_profile_args(
ref_prog: Callable = None, supply_type=self.supply_type,
supply_prog: Callable = None, ref_prog=self.ref_prog,
rtol: float = 1e-2, supply_prog=self.supply_prog,
atol: float = 1e-2, rtol=self.rtol,
max_mismatched_ratio: float = 0.01, atol=self.atol,
skip_check: bool = False, max_mismatched_ratio=self.max_mismatched_ratio,
manual_check_prog: Callable = None, skip_check=self.skip_check,
cache_input_tensors: bool = False) -> None: manual_check_prog=self.manual_check_prog,
"""Initialize the AutoTunerImplementation. cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=self.jit_impl.out_idx,
execution_backend=self.jit_impl.execution_backend,
target=self.jit_impl.target,
target_host=self.jit_impl.target_host,
verbose=self.jit_impl.verbose,
pass_configs=self.jit_impl.pass_configs,
)
autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
return autotuner
Args: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel:
configs: Configuration space to explore during auto-tuning. key_args_tuple = args
warmup: Number of warmup iterations before timing. key_kwargs_tuple = tuple(sorted(kwargs.items()))
rep: Number of repetitions for timing measurements. key = (key_args_tuple, key_kwargs_tuple)
timeout: Maximum time (in seconds) allowed for each configuration. if key not in self._tuner_cache:
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation def jit_compile(**config_arg):
supply_prog: Custom function to provide input tensors return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation autotuner = self.get_tunner()
max_mismatched_ratio: Allowed percentage of mismatched values autotuner.jit_compile = jit_compile
skip_check: Bypass validation against reference implementation autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
manual_check_prog: Custom validation function artifact = autotuner.run()
cache_input_tensors: Reuse input tensors across trials self._tuner_cache[key] = artifact.kernel
""" return self._tuner_cache[key]
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
return wrapper
def autotune( # This is the new public interface def autotune( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None, func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
configs: dict | Callable, configs: dict | Callable,
# profile arguments # profile arguments
...@@ -795,22 +731,26 @@ def autotune( # This is the new public interface ...@@ -795,22 +731,26 @@ def autotune( # This is the new public interface
elif isinstance(func, PrimFunc): elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else: else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments. def decorator(impl):
# This instance is a decorator that will be applied to the function later. assert isinstance(
configured_decorator = _AutoTunerImplementation( impl, JITImpl
configs=configs, ), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
warmup=warmup, return AutoTuneImpl(
rep=rep, jit_impl=impl,
timeout=timeout, configs=configs,
supply_type=supply_type, warmup=warmup,
ref_prog=ref_prog, rep=rep,
supply_prog=supply_prog, timeout=timeout,
rtol=rtol, supply_type=supply_type,
atol=atol, ref_prog=ref_prog,
max_mismatched_ratio=max_mismatched_ratio, supply_prog=supply_prog,
skip_check=skip_check, rtol=rtol,
manual_check_prog=manual_check_prog, atol=atol,
cache_input_tensors=cache_input_tensors, max_mismatched_ratio=max_mismatched_ratio,
) skip_check=skip_check,
return configured_decorator manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return decorator
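With this change, autotune no longer wraps a bare Python function: the decorator now expects the JITImpl produced by @tilelang.jit and returns an AutoTuneImpl that caches one tuned JITKernel per argument key. A minimal usage sketch follows; the import path, the @tilelang.jit arguments, the dict-of-lists config layout, and the elided kernel body are illustrative assumptions, not part of this commit.

import tilelang
from tilelang.autotuner import autotune  # import path assumed

@autotune(configs={"block_M": [64, 128], "block_N": [64, 128]})  # assumed search-space layout
@tilelang.jit(out_idx=[-1])  # autotune must wrap a @tilelang.jit instance (JITImpl)
def matmul(M, N, K, block_M=128, block_N=128):
    ...  # kernel body elided; the tuned knobs map onto the keyword parameters

# The first call with a given argument key runs the tuner; the resulting JITKernel
# is cached in AutoTuneImpl._tuner_cache and reused for identical calls.
kernel = matmul(1024, 1024, 1024)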
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import functools import functools
import math import math
from queue import PriorityQueue from queue import PriorityQueue
from typing import Iterable from collections.abc import Iterable
import numpy as np import numpy as np
import tvm import tvm
......
from __future__ import annotations from __future__ import annotations
from typing import Mapping from collections.abc import Mapping
from tvm.tir.schedule.schedule import BlockRV from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal from tvm.ir import structural_equal
from tvm import arith, tir from tvm import arith, tir
......
...@@ -64,7 +64,7 @@ def get_cc(): ...@@ -64,7 +64,7 @@ def get_cc():
return None return None
@functools.lru_cache(maxsize=None) @functools.cache
def get_cplus_compiler(): def get_cplus_compiler():
"""Return the path to the default C/C++ compiler. """Return the path to the default C/C++ compiler.
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM""" """Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm.runtime import ndarray from tvm import runtime
def convert_func(tvm_func, tensor_type, to_dlpack_func): def convert_func(tvm_func, tensor_type, to_dlpack_func):
...@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): ...@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz torch.float8_e5m2fnuz
}: }:
return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype]) arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack_func(arg)) return runtime.from_dlpack(to_dlpack_func(arg))
return arg return arg
def _wrapper(*args): def _wrapper(*args):
......
...@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs ...@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs
import subprocess import subprocess
import tvm.ffi import tvm_ffi
from tvm.contrib import utils from tvm.contrib import utils
from tvm.base import py_str from tvm.base import py_str
...@@ -97,7 +97,7 @@ def compile_hip(code, ...@@ -97,7 +97,7 @@ def compile_hip(code,
return data return data
@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target): def tilelang_callback_hip_compile(code, target):
"""use hipcc to generate fatbin code for better optimization""" """use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco") hsaco = compile_hip(code, target_format="hsaco")
......
...@@ -7,9 +7,12 @@ from __future__ import annotations ...@@ -7,9 +7,12 @@ from __future__ import annotations
import os import os
import subprocess import subprocess
import warnings import warnings
from tilelang.env import CUDA_HOME import contextlib
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
import tvm.ffi import shutil
import tempfile
import tvm_ffi
from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tvm.base import py_str from tvm.base import py_str
...@@ -125,6 +128,154 @@ def compile_cuda(code, ...@@ -125,6 +128,154 @@ def compile_cuda(code,
return data return data
def default_compile_options(compile_flags: list[str] | None = None) -> list[str]:
"""
Build a set of default NVCC compile options for TileLang generated sources.
Includes C++ standard and common include paths (TileLang templates, CUTLASS,
CUDA include). Merges user-provided compile flags if given.
Parameters
----------
compile_flags : Optional[List[str]]
Additional flags to include. Items are split on whitespace.
Returns
-------
List[str]
A list of flags suitable for NVCC's command line.
"""
options: list[str] = ["-std=c++17"]
try:
if TILELANG_TEMPLATE_PATH:
options.append(f"-I{TILELANG_TEMPLATE_PATH}")
except Exception:
pass
try:
if CUTLASS_INCLUDE_DIR:
options.append(f"-I{CUTLASS_INCLUDE_DIR}")
except Exception:
pass
try:
if CUDA_HOME:
options.append(f"-I{os.path.join(CUDA_HOME, 'include')}")
except Exception:
pass
# Preserve user flags exactly, including repeated tokens required by NVCC
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if compile_flags:
import shlex
for flag in compile_flags:
# Split each string like a shell would, preserving quoted args
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
options.extend(tokens)
return options
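As an illustration of the merge behavior (include paths shown as placeholders, since the actual directories depend on the local installation), user flags are shlex-split so multi-token strings reach NVCC as separate tokens:

opts = default_compile_options(["-gencode arch=compute_90,code=sm_90", "-lineinfo"])
# -> ["-std=c++17", "-I<templates>", "-I<cutlass>", "-I<cuda>/include",
#     "-gencode", "arch=compute_90,code=sm_90", "-lineinfo"]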
def get_ptx_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to PTX using NVCC and return as text.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print NVCC output when True.
Returns
-------
str
PTX text.
"""
opts = default_compile_options(compile_flags)
ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose)
try:
return ptx_bytes.decode("utf-8")
except Exception:
return str(ptx_bytes)
def _find_tool(name: str) -> str | None:
"""Find a CUDA binary in PATH or under CUDA_HOME/bin."""
path = shutil.which(name)
if path:
return path
if CUDA_HOME:
candidate = os.path.join(CUDA_HOME, "bin", name)
if os.path.exists(candidate):
return candidate
return None
def get_sass_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to CUBIN and disassemble to SASS.
Uses nvdisasm if available; otherwise falls back to cuobjdump.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print tool outputs when True.
Returns
-------
str
SASS text.
"""
opts = default_compile_options(compile_flags)
cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose)
# Write to a temp .cubin file
with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp:
tmp.write(cubin_bytes)
cubin_path = tmp.name
# Try disassembly tools (prefer nvdisasm, fallback cuobjdump)
cand_nvdisasm = _find_tool("nvdisasm")
cand_cuobjdump = _find_tool("cuobjdump")
if not cand_nvdisasm and not cand_cuobjdump:
raise RuntimeError(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
last_err: str | None = None
try:
# Attempt nvdisasm first
tools_to_try = []
if cand_nvdisasm:
tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path]))
if cand_cuobjdump:
tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path]))
for tool_name, cmd in tools_to_try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate()
text = py_str(out)
if verbose:
print(f"[{tool_name}] output:\n{text}")
if proc.returncode == 0 and text.strip():
return text
last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
# If we reach here, all attempts failed
raise RuntimeError(f"SASS disassembly failed. Tried tools: "
f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
finally:
with contextlib.suppress(Exception):
os.remove(cubin_path)
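A hedged usage sketch for the two helpers above; it assumes a local CUDA toolkit, and the kernel source plus the -arch flag are illustrative only:

if __name__ == "__main__":
    src = r'''
    extern "C" __global__ void add_one(float *x) { x[threadIdx.x] += 1.0f; }
    '''
    ptx = get_ptx_from_source(src, compile_flags=["-arch=sm_80"])
    sass = get_sass_from_source(src, compile_flags=["-arch=sm_80"], verbose=False)
    print(ptx.splitlines()[0])
    print(f"SASS lines: {len(sass.splitlines())}")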
def find_cuda_path(): def find_cuda_path():
"""Utility function to find cuda path """Utility function to find cuda path
...@@ -182,14 +333,14 @@ def get_cuda_version(cuda_path=None): ...@@ -182,14 +333,14 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file") raise RuntimeError("Cannot read cuda version file")
@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization""" """use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin") ptx = compile_cuda(code, target_format="fatbin")
return ptx return ptx
@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True) @tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch): def find_libdevice_path(arch):
"""Utility function to find libdevice """Utility function to find libdevice
...@@ -254,7 +405,7 @@ def callback_libdevice_path(arch): ...@@ -254,7 +405,7 @@ def callback_libdevice_path(arch):
return "" return ""
@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True)
def get_target_compute_version(target=None): def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target. """Utility function to get compute capability of compilation target.
...@@ -400,7 +551,7 @@ def have_cudagraph(): ...@@ -400,7 +551,7 @@ def have_cudagraph():
return False return False
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True)
def have_bf16(compute_version): def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not """Either bf16 support is provided in the compute capability or not
...@@ -413,7 +564,7 @@ def have_bf16(compute_version): ...@@ -413,7 +564,7 @@ def have_bf16(compute_version):
return major >= 8 return major >= 8
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True)
def have_fp8(compute_version): def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not """Whether fp8 support is provided in the specified compute capability or not
...@@ -430,7 +581,7 @@ def have_fp8(compute_version): ...@@ -430,7 +581,7 @@ def have_fp8(compute_version):
return any(conditions) return any(conditions)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target): def have_tma(target):
"""Whether TMA support is provided in the specified compute capability or not """Whether TMA support is provided in the specified compute capability or not
......
...@@ -21,7 +21,7 @@ import subprocess ...@@ -21,7 +21,7 @@ import subprocess
import os import os
from os.path import join, exists from os.path import join, exists
import tvm.ffi import tvm_ffi
from tvm.base import py_str from tvm.base import py_str
import tvm.runtime import tvm.runtime
import tvm.target import tvm.target
...@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None): ...@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
raise RuntimeError(msg) raise RuntimeError(msg)
@tvm.ffi.register_func("tvm_callback_rocm_link", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True)
def callback_rocm_link(obj_bin): def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object """Links object file generated from LLVM to HSA Code Object
...@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin): ...@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
return cobj_bin return cobj_bin
@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True)
def callback_rocm_bitcode_path(rocdl_dir=None): def callback_rocm_bitcode_path(rocdl_dir=None):
"""Utility function to find ROCm device library bitcodes """Utility function to find ROCm device library bitcodes
...@@ -226,8 +226,11 @@ def have_matrixcore(compute_version=None): ...@@ -226,8 +226,11 @@ def have_matrixcore(compute_version=None):
return False return False
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/dtk"): @tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"):
# @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
# def get_rocm_arch(rocm_path="/opt/dtk"):
"""Utility function to get the AMD GPU architecture """Utility function to get the AMD GPU architecture
Parameters Parameters
......
from __future__ import annotations from __future__ import annotations
from typing import Callable from typing import Callable
from tvm import register_func import tvm_ffi
from tvm.target import Target from tvm.target import Target
...@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = ...@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool =
and returns the processed code (str). and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True. override: Whether to override existing registered function. Defaults to True.
""" """
register_func("tilelang_callback_cuda_postproc", f=func, override=override) tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override)
def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True): def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
...@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T ...@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
and returns the processed code (str). and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True. override: Whether to override existing registered function. Defaults to True.
""" """
register_func("tilelang_callback_hip_postproc", f=func, override=override) tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)
def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
......
...@@ -7,6 +7,7 @@ from typing import Callable ...@@ -7,6 +7,7 @@ from typing import Callable
import tilelang.transform import tilelang.transform
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import tir from tvm import tir
import tvm_ffi
from tvm.ir import CallingConv from tvm.ir import CallingConv
from tvm.target import Target from tvm.target import Target
from tilelang.contrib import hipcc, nvcc from tilelang.contrib import hipcc, nvcc
...@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: ...@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return lambda func: not get_device_call(is_device_c)(func) return lambda func: not get_device_call(is_device_c)(func)
@tvm.register_func("tilelang_callback_cuda_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): def tilelang_callback_cuda_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..") project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ: if "TL_TEMPLATE_PATH" in os.environ:
...@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target): ...@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target):
return ptx return ptx
@tvm.register_func("tilelang_callback_hip_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target): def tilelang_callback_hip_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..") project_root = osp.join(osp.dirname(__file__), "../..")
tl_template_path = osp.abspath(osp.join(project_root, "src")) tl_template_path = osp.abspath(osp.join(project_root, "src"))
...@@ -182,7 +183,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -182,7 +183,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
elif target.kind.name == "llvm": elif target.kind.name == "llvm":
device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu": elif target.kind.name == "webgpu":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target)
elif target.kind.name == "metal": elif target.kind.name == "metal":
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
else: else:
...@@ -241,6 +242,6 @@ def lower( ...@@ -241,6 +242,6 @@ def lower(
host_mod = host_codegen(host_mod, target_host) host_mod = host_codegen(host_mod, target_host)
host_mod.import_module(codegen_mod) host_mod.import_module(codegen_mod)
return CompiledArtifact( return CompiledArtifact(
host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod) host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source()) return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
...@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LetInline()(mod) mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store # Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Normalize negative indices to canonical non-negative form
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
# Inject assumes to speedup tvm prover # Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod) mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions # Simplify the IR expressions
...@@ -118,8 +120,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -118,8 +120,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# TODO(lei): return to tir pass when kSymbolicBound simplification # TODO(lei): return to tir pass when kSymbolicBound simplification
# is merged into tvm. # is merged into tvm.
mod = tilelang.transform.Simplify()(mod) mod = tilelang.transform.Simplify()(mod)
# Try to vectorize loop with dynamic shape
mod = tilelang.transform.LoopVectorizeDynamic()(mod)
return mod return mod
......
...@@ -236,6 +236,10 @@ class Environment: ...@@ -236,6 +236,10 @@ class Environment:
"1") # print kernel name on compile "1") # print kernel name on compile
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")
# Auto-tuning settings # Auto-tuning settings
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
"0.9") # percent of CPUs used "0.9") # percent of CPUs used
...@@ -274,6 +278,14 @@ class Environment: ...@@ -274,6 +278,14 @@ class Environment:
def is_print_on_compilation_enabled(self) -> bool: def is_print_on_compilation_enabled(self) -> bool:
return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on") return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")
def use_gemm_v1(self) -> bool:
"""Return True if GEMM v1 should be used based on env.
Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of
{"1", "true", "yes", "on"} (case-insensitive).
"""
return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")
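A minimal sketch of consuming the new toggle; the selector helper is hypothetical, and the variable should be exported before importing tilelang since this sketch does not assume EnvVar re-reads the environment afterwards:

# export TILELANG_USE_GEMM_V1=1   # any of "1"/"true"/"yes"/"on" forces the v1 path
from tilelang import env

def select_gemm_version() -> int:  # hypothetical helper, for illustration only
    return 1 if env.use_gemm_v1() else 2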
# Instantiate as a global configuration object # Instantiate as a global configuration object
env = Environment() env = Environment()
...@@ -297,12 +309,11 @@ def prepend_pythonpath(path): ...@@ -297,12 +309,11 @@ def prepend_pythonpath(path):
if env.TVM_IMPORT_PYTHON_PATH is not None: if env.TVM_IMPORT_PYTHON_PATH is not None:
prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH) prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
else: else:
tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm") tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
assert os.path.exists(tvm_path), tvm_path assert os.path.exists(tvm_path), tvm_path
if tvm_path not in sys.path: if tvm_path not in sys.path:
tvm_python_binding = os.path.join(tvm_path, 'python') prepend_pythonpath(tvm_path)
prepend_pythonpath(tvm_python_binding) env.TVM_IMPORT_PYTHON_PATH = tvm_path
env.TVM_IMPORT_PYTHON_PATH = tvm_python_binding
if os.environ.get("TVM_LIBRARY_PATH") is None: if os.environ.get("TVM_LIBRARY_PATH") is None:
os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
......
...@@ -2,10 +2,32 @@ from __future__ import annotations ...@@ -2,10 +2,32 @@ from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.language as T import tilelang.language as T
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm.runtime import convert from tvm.runtime import convert
from .utils import ( from .utils import (
mfma_store_index_map,) mfma_store_index_map,)
from typing import Literal, Callable
from tilelang.utils import is_fragment
from tilelang.utils.language import to_buffer_region
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
lift = convert lift = convert
...@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter: ...@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter:
k_pack: int | None = None, k_pack: int | None = None,
is_m_first: bool | None = False, is_m_first: bool | None = False,
b_preshuffle: bool | None = False, b_preshuffle: bool | None = False,
thread_var: Var | None = None,
): ):
self.a_dtype = a_dtype self.a_dtype = a_dtype
self.b_dtype = b_dtype self.b_dtype = b_dtype
...@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter: ...@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter:
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
...@@ -115,6 +139,7 @@ class MatrixCoreIntrinEmitter: ...@@ -115,6 +139,7 @@ class MatrixCoreIntrinEmitter:
}[out_dtype] }[out_dtype]
in_dtype_abbrv = { in_dtype_abbrv = {
"bfloat16": "bf16",
"float16": "f16", "float16": "f16",
"float32": "f32", "float32": "f32",
"int8": "i8", "int8": "i8",
...@@ -126,6 +151,9 @@ class MatrixCoreIntrinEmitter: ...@@ -126,6 +151,9 @@ class MatrixCoreIntrinEmitter:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
elif in_dtype_abbrv == "i8": elif in_dtype_abbrv == "i8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
elif in_dtype_abbrv == "bf16":
# HIP intrinsic uses ...x{K}bf16_1k without an underscore before bf16
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}bf16_1k"
else: else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
...@@ -147,24 +175,6 @@ class MatrixCoreIntrinEmitter: ...@@ -147,24 +175,6 @@ class MatrixCoreIntrinEmitter:
self.b_preshuffle = b_preshuffle self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False): def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
k_dim = self.k_dim * self.k_pack k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed transposed = self.a_transposed if not is_b else self.b_transposed
...@@ -200,6 +210,22 @@ class MatrixCoreIntrinEmitter: ...@@ -200,6 +210,22 @@ class MatrixCoreIntrinEmitter:
return index_map, reverse_index_map return index_map, reverse_index_map
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def extract_thread_binding(self, def extract_thread_binding(self,
thread_id, thread_id,
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
...@@ -229,7 +255,7 @@ class MatrixCoreIntrinEmitter: ...@@ -229,7 +255,7 @@ class MatrixCoreIntrinEmitter:
(WARP_SIZE * block_row_warps)) % block_col_warps, (WARP_SIZE * block_row_warps)) % block_col_warps,
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -238,10 +264,15 @@ class MatrixCoreIntrinEmitter: ...@@ -238,10 +264,15 @@ class MatrixCoreIntrinEmitter:
local_size_a = self.local_size_a local_size_a = self.local_size_a
k_pack = self.k_pack k_pack = self.k_pack
is_transposed = self.a_transposed is_transposed = self.a_transposed
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -257,20 +288,20 @@ class MatrixCoreIntrinEmitter: ...@@ -257,20 +288,20 @@ class MatrixCoreIntrinEmitter:
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k), l, r = (rk * chunk + ki * (k_pack * micro_size_k),
warp_m * warp_row_tiles + i * micro_size_x) warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
r + col] A_base1 + r + col]
else: else:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * (k_pack * micro_size_k)) rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
r + col] A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -279,10 +310,15 @@ class MatrixCoreIntrinEmitter: ...@@ -279,10 +310,15 @@ class MatrixCoreIntrinEmitter:
local_size_b = self.local_size_b local_size_b = self.local_size_b
k_pack = self.k_pack k_pack = self.k_pack
is_transposed = self.b_transposed is_transposed = self.b_transposed
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_b( def _warp_ldmatrix_b(
B_local_buf, B_local_buf,
...@@ -300,8 +336,8 @@ class MatrixCoreIntrinEmitter: ...@@ -300,8 +336,8 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
r + col] B_base1 + r + col]
else: else:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
...@@ -311,12 +347,16 @@ class MatrixCoreIntrinEmitter: ...@@ -311,12 +347,16 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
r + col] B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mfma(self, A_local_buf, B_local_buf, C_local_buf): def mfma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -329,8 +369,13 @@ class MatrixCoreIntrinEmitter: ...@@ -329,8 +369,13 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
@T.macro @T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf): def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols): for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
T.tvm_mfma( T.tvm_mfma(
mfma_suffix, mfma_suffix,
...@@ -340,15 +385,15 @@ class MatrixCoreIntrinEmitter: ...@@ -340,15 +385,15 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype, compute_b_dtype,
compute_out_dtype, compute_out_dtype,
B_local_buf.data, B_local_buf.data,
((j * k_pack + kp) * local_size_b) // local_size_b, (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b,
A_local_buf.data, A_local_buf.data,
((i * k_pack + kp) * local_size_a) // local_size_a, (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a,
C_local_buf.data, C_local_buf.data,
(i * warp_cols * local_size_out + j * local_size_out) // local_size_out, (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
dtype=compute_out_dtype, dtype=compute_out_dtype,
) )
return _warp_mma(A_local_buf, B_local_buf, C_local_buf) return _warp_mfma(A_local_buf, B_local_buf, C_local_buf)
def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps block_row_warps = self.block_row_warps
...@@ -356,8 +401,7 @@ class MatrixCoreIntrinEmitter: ...@@ -356,8 +401,7 @@ class MatrixCoreIntrinEmitter:
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_out = self.local_size_out local_size_out = self.local_size_out
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
is_global = pid_m is not None and pid_n is not None is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols BLOCK_N = block_col_warps * warp_cols
...@@ -366,7 +410,7 @@ class MatrixCoreIntrinEmitter: ...@@ -366,7 +410,7 @@ class MatrixCoreIntrinEmitter:
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
# STS # STS
# MMA Store must be in simulated instead of TVM Intrins # MFMA store must be simulated manually instead of using TVM intrinsics
# As TVM Intrins is like a hack that the threadIdx.x should be always # As TVM Intrins is like a hack that the threadIdx.x should be always
# equal to the warp_size # equal to the warp_size
@T.macro @T.macro
...@@ -400,6 +444,217 @@ class MatrixCoreIntrinEmitter: ...@@ -400,6 +444,217 @@ class MatrixCoreIntrinEmitter:
thread_binding) if is_global else _warp_stmatrix_shared( thread_binding) if is_global else _warp_stmatrix_shared(
C_local_buf, C_buf, thread_binding) C_local_buf, C_buf, thread_binding)
def make_mfma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout describing how an MFMA operand fragment (matrix A or B) is loaded from shared memory into a local fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
k_dim = self.k_dim * self.k_pack
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to the lane (thread) index according to `inverse_mfma_load_layout`.
"""
lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to the local index within a single thread according to `inverse_mfma_load_layout`.
"""
_, local_id = inverse_mfma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mfma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...@@ -421,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): ...@@ -421,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
is_m_first: bool | None = False, is_m_first: bool | None = False,
a_preshuffle: bool | None = False, a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False, b_preshuffle: bool | None = False,
thread_var: Var | None = None,
): ):
super().__init__(
self.a_dtype = a_dtype a_dtype=a_dtype,
self.b_dtype = b_dtype b_dtype=b_dtype,
self.accum_dtype = accum_dtype accum_dtype=accum_dtype,
self.a_transposed = a_transposed a_transposed=a_transposed,
self.b_transposed = b_transposed b_transposed=b_transposed,
# Hint Information block_row_warps=block_row_warps,
self.block_row_warps = block_row_warps block_col_warps=block_col_warps,
self.block_col_warps = block_col_warps warp_row_tiles=warp_row_tiles,
self.warp_row_tiles = warp_row_tiles warp_col_tiles=warp_col_tiles,
self.warp_col_tiles = warp_col_tiles chunk=chunk,
self.chunk = chunk reduce_k=reduce_k,
self._initialize_k_dim(a_dtype) num_elems_per_byte=num_elems_per_byte,
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) k_pack=k_pack,
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) is_m_first=is_m_first,
self._initialize_mfma_prefix(self.k_dim) thread_var=thread_var,
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) )
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
self._initialize_preshuffle(a_preshuffle, b_preshuffle) self._initialize_preshuffle(a_preshuffle, b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None: if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle self.a_preshuffle = a_preshuffle
......
...@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): ...@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col return row, col
def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
row = thread_id // 4
col = (thread_id % 4) * 2 + local_id
return row, col
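For reference, a standalone sanity check (plain Python, not part of the patch) that the new fp64 store layout assigns the 32 lanes x 2 registers of a warp to the 8x8 accumulator tile exactly once:

```python
# Mirrors mma_store_32x2_to_shared_8x8_layout_fp64 above, re-declared only for the check.
def store_fp64(thread_id, local_id):
    row = thread_id // 4
    col = (thread_id % 4) * 2 + local_id
    return row, col

cells = {store_fp64(t, l) for t in range(32) for l in range(2)}
assert cells == {(r, c) for r in range(8) for c in range(8)}  # 64 unique cells, full 8x8 tile
```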
# sr represents spatial + reduction layout # sr represents spatial + reduction layout
# the first axis is spatial while the second axis is reduction # the first axis is spatial while the second axis is reduction
# mma.sync matrix A layout, if wanna trans, please apply map_indices # mma.sync matrix A layout, if wanna trans, please apply map_indices
......
...@@ -3,13 +3,14 @@ import tilelang.language as T ...@@ -3,13 +3,14 @@ import tilelang.language as T
from typing import Literal, Callable from typing import Literal, Callable
from tilelang.common import TransformKind from tilelang.common import TransformKind
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert from tvm.runtime import convert
from .utils import ( from .utils import (
mma_store_index_map, mma_store_index_map,
get_ldmatrix_offset, get_ldmatrix_offset,
) )
from tilelang.utils import is_fragment from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_layout import ( from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b, shared_16x8_to_mma_32x4_layout_sr_b,
...@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter: ...@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter:
"float16": "fp16", "float16": "fp16",
"bfloat16": "bf16", "bfloat16": "bf16",
"float32": "fp32", "float32": "fp32",
"float64": "fp64",
"int8": "int8", "int8": "int8",
"int32": "int32", "int32": "int32",
"float8_e4m3": "e4m3", "float8_e4m3": "e4m3",
...@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter: ...@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter:
self.warp_col_tiles = warp_col_tiles self.warp_col_tiles = warp_col_tiles
self.chunk = chunk self.chunk = chunk
self._initialize_k_dim(a_dtype) self._initialize_k_dim(a_dtype)
# For FP64, MMA shape is m8n8k4; adjust instance dims early
if DataType(a_dtype).bits == 64:
# Override default M/N dims for fp64 MMA
self.M_DIM = 8
# n_dim will be set to 8 in _initialize_micro_size via k_dim==4
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim) self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE) self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
...@@ -105,12 +112,21 @@ class TensorCoreIntrinEmitter: ...@@ -105,12 +112,21 @@ class TensorCoreIntrinEmitter:
self.local_size_out = (m_dim * n_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16): def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 8: if k_dim == 4:
# fp64
self.mma_prefix = "m8n8k4"
elif k_dim == 8:
# typically used for tfloat32 # typically used for tfloat32
self.mma_prefix = "m16n8k8" self.mma_prefix = "m16n8k8"
elif k_dim == 16: elif k_dim == 16:
...@@ -125,22 +141,31 @@ class TensorCoreIntrinEmitter: ...@@ -125,22 +141,31 @@ class TensorCoreIntrinEmitter:
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" # For fp64 (k_dim==4), micro tile is 8x8, otherwise keep 16x{8|16}
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" if k_dim == 4:
assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" # fp64 path: m_dim must be 8, n_dim 8
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8 self.n_dim = 8
self.micro_size_y = 8 self.micro_size_y = 8
self.warp_rows = warp_row_tiles // m_dim
self.warp_cols = warp_col_tiles // 8 self.warp_cols = warp_col_tiles // 8
else:
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim self.micro_size_x = m_dim
self.micro_size_k = k_dim self.micro_size_k = k_dim
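To make the branch above concrete, a small sketch (assumed tile sizes, plain Python outside TVM) of how the fp64 path (k_dim == 4, 8x8 micro tile) changes the per-warp tiling compared with the 16-bit path:

```python
def micro_tiling(m_dim, k_dim, warp_row_tiles, warp_col_tiles):
    # Mirrors _initialize_micro_size: fp64 uses an 8x8 micro tile, other dtypes 16x{8,16}.
    if k_dim == 4:
        return dict(n_dim=8, warp_rows=warp_row_tiles // m_dim, warp_cols=warp_col_tiles // 8)
    n_dim = 16 if warp_col_tiles % 16 == 0 else 8
    return dict(n_dim=n_dim, warp_rows=warp_row_tiles // m_dim, warp_cols=warp_col_tiles // n_dim)

print(micro_tiling(m_dim=8, k_dim=4, warp_row_tiles=32, warp_col_tiles=32))
# fp64: 4 x 4 micro tiles of 8x8 per warp
print(micro_tiling(m_dim=16, k_dim=16, warp_row_tiles=32, warp_col_tiles=32))
# fp16: 2 x 2 micro tiles of 16x16 per warp
```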
...@@ -158,8 +183,12 @@ class TensorCoreIntrinEmitter: ...@@ -158,8 +183,12 @@ class TensorCoreIntrinEmitter:
return self.thread_var return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
from .utils import mma_store_index_map, mma_store_index_map_fp64
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
else:
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if not inverse: if not inverse:
return index_map return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
...@@ -199,9 +228,47 @@ class TensorCoreIntrinEmitter: ...@@ -199,9 +228,47 @@ class TensorCoreIntrinEmitter:
def ldmatrix_a(self, def ldmatrix_a(self,
A_local_buf: Buffer, A_local_buf: Buffer,
A_shared_buf: Buffer, A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.a_dtype).bits == 64:
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x # 8
micro_size_k = self.micro_size_k # 4
local_size_a = self.local_size_a # 1
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ld_a_fp64(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
wi = warp_m * warp_row_tiles + i * micro_size_x
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if a_transposed:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -226,6 +293,13 @@ class TensorCoreIntrinEmitter: ...@@ -226,6 +293,13 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
A_stride_last = A_buf.shape[-1]
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -234,14 +308,16 @@ class TensorCoreIntrinEmitter: ...@@ -234,14 +308,16 @@ class TensorCoreIntrinEmitter:
thread_binding, thread_binding,
rk=0, rk=0,
): ):
stride = A_shared_buf.shape[-1] stride = A_stride_last
tx, _, warp_m = self.extract_thread_binding(thread_binding) tx, _, warp_m = self.extract_thread_binding(thread_binding)
trans = self.a_transposed trans = self.a_transposed
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
# Assign A_shared_buf_elem # Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] A_shared_buf_elem = A_buf[A_base0 + wk,
A_base1 + wi] if a_transposed else A_buf[A_base0 + wi,
A_base1 + wk]
if ldmatrix_available: if ldmatrix_available:
T.ptx_ldmatrix( T.ptx_ldmatrix(
...@@ -257,15 +333,59 @@ class TensorCoreIntrinEmitter: ...@@ -257,15 +333,59 @@ class TensorCoreIntrinEmitter:
else: else:
for j in T.serial(local_size_a): for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if a_transposed:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk,
A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi,
A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
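Since ldmatrix has no fp64 form, the fp64 fast path above issues one plain load per lane. A quick check (plain Python, assuming the 8x4 A micro tile of m8n8k4) that the addressing mi = tx // 4, mk = tx % 4 covers the micro tile exactly once across a warp:

```python
micro_size_x, micro_size_k = 8, 4   # fp64 m8n8k4 A micro tile
lanes = {(tx // micro_size_k, tx % micro_size_k) for tx in range(32)}
assert lanes == {(mi, mk) for mi in range(micro_size_x) for mk in range(micro_size_k)}
# each of the 32 lanes owns exactly one fp64 element (local_size_a == 1)
```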
def ldmatrix_b(self, def ldmatrix_b(self,
B_local_buf: Buffer, B_local_buf: Buffer,
B_shared_buf: Buffer, B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr, ki: PrimExpr,
rk: PrimExpr | None = 0): rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.b_dtype).bits == 64:
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y # 8
micro_size_k = self.micro_size_k # 4
local_size_b = self.local_size_b # 1
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ld_b_fp64(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for j in T.serial(warp_cols):
wi = warp_n * warp_col_tiles + j * micro_size_y
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if b_transposed:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
else:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -275,6 +395,13 @@ class TensorCoreIntrinEmitter: ...@@ -275,6 +395,13 @@ class TensorCoreIntrinEmitter:
b_dtype = self.b_dtype b_dtype = self.b_dtype
b_transposed = self.b_transposed b_transposed = self.b_transposed
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
B_stride_last = B_buf.shape[-1]
replicate_b = (self.n_dim == 16) replicate_b = (self.n_dim == 16)
# ldmatrix cannot be used for int8 + trans case. # ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...@@ -298,7 +425,7 @@ class TensorCoreIntrinEmitter: ...@@ -298,7 +425,7 @@ class TensorCoreIntrinEmitter:
thread_binding, thread_binding,
rk=0, rk=0,
): ):
stride = B_shared_buf.shape[-1] stride = B_stride_last
tx, warp_n, _ = self.extract_thread_binding(thread_binding) tx, warp_n, _ = self.extract_thread_binding(thread_binding)
trans = not b_transposed trans = not b_transposed
...@@ -310,8 +437,9 @@ class TensorCoreIntrinEmitter: ...@@ -310,8 +437,9 @@ class TensorCoreIntrinEmitter:
) )
if ldmatrix_available: if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, B_shared_buf_elem = B_buf[B_base0 + wi,
wi] B_base1 + wk] if b_transposed else B_buf[B_base0 + wk,
B_base1 + wi]
T.ptx_ldmatrix( T.ptx_ldmatrix(
b_dtype, b_dtype,
...@@ -329,7 +457,12 @@ class TensorCoreIntrinEmitter: ...@@ -329,7 +457,12 @@ class TensorCoreIntrinEmitter:
# must be transposed. # must be transposed.
for j in T.serial(local_size_b): for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] if b_transposed:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
...@@ -617,8 +750,10 @@ class TensorCoreIntrinEmitter: ...@@ -617,8 +750,10 @@ class TensorCoreIntrinEmitter:
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
shape = local_buf.shape shape = local_buf.shape
assert is_fragment(
local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
inverse_mma_store_layout = self.get_store_index_map(inverse=True) inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
......
from __future__ import annotations
def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
local_id = col
return tid, local_id
def shared_4x16_to_mma_b_32x4_layout(row, col, rep):
thread_id = row + 8 * col // 4 + 4 * rep
local_id = col % 4
return thread_id, local_id
def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8)
local_id = col
return thread_id, local_id
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + (
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
return row, col
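A coverage check (plain Python, not part of the new file) that the fp32 store mapping above is a bijection from (thread_id in [0, 32), local_id in [0, 8)) onto the 16x16 tile:

```python
def store_fp32(thread_id, local_id):
    # Same arithmetic as mma_32x8_to_shared_16x16_layout_fp32 above.
    row = (thread_id % 2) + (local_id // 2 % 2) * 2 + 4 * (thread_id // 16) \
        + (thread_id % 16 // 4) % 2 * 8
    col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + local_id % 2 \
        + (local_id // 4) * 8
    return row, col

assert len({store_fp32(t, l) for t in range(32) for l in range(8)}) == 256  # 16x16 covered once
```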
def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8
col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8
return row, col
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + 4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4)
col = local_id
return row, col
def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id):
row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2)
col = local_id
return row, col
def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id):
row = thread_id % 4
col = local_id + (4 * (thread_id // 8))
return row, col
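Similarly, a round-trip check (plain Python) that the A-operand pair above is self-consistent: shared -> (thread, register) -> shared recovers the original coordinates for every replication index:

```python
def a_shared_to_mma(row, col, rep):            # shared_16x4_to_mma_a_32x4_layout
    return (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep, col

def a_mma_to_shared(thread_id, local_id):      # mma_load_a_32x4_to_shared_16x4_layout
    row = (thread_id % 4) + 4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4)
    return row, local_id

for row in range(16):
    for col in range(4):
        for rep in range(2):
            tid, lid = a_shared_to_mma(row, col, rep)
            assert a_mma_to_shared(tid, lid) == (row, col)
```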
from __future__ import annotations
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout,
shared_16x4_to_mma_b_32x4_layout_trans,
mma_32x8_to_shared_16x16_layout_fp32,
mma_32x8_to_shared_16x16_layout_fp16,
mma_load_a_32x4_to_shared_16x4_layout,
mma_load_b_32x4_to_shared_16x4_layout_trans,
mma_load_b_32x4_to_shared_4x16_layout,
)
lift = convert
class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
M_DIM = 16
# use lowercase as n_dim can be dynamic
# SM70 mma.sync uses the m16n16k4 shape, so n_dim is fixed to 16 here
n_dim = 16
WARP_SIZE = 32
HALF_WARP_SIZE = WARP_SIZE // 2
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim)
self._initialize_mma_prefix(self.k_dim)
self._initialize_is_m_first(is_m_first)
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 4
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
self.local_size_a = (m_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_b = (n_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_out = (m_dim * n_dim) // self.WARP_SIZE
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 4:
# typically used for float16
self.mma_prefix = "m16n16k4"
else:
raise ValueError(f"Unsupported k_dim: {k_dim}")
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 16, f"warp_col_tiles must be at least 16, got {warp_col_tiles}"
assert warp_col_tiles % 16 == 0, f"warp_col_tiles must be divisible by 16, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32
if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
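A worked example (assumed 2x2 warp grid, not part of the file) of the default is_m_first = False decomposition used throughout the emitter:

```python
WARP_SIZE, block_row_warps, block_col_warps = 32, 2, 2

def extract(thread_id):
    lane_id = thread_id % WARP_SIZE
    warp_m = (thread_id // WARP_SIZE) % block_row_warps
    warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_n, warp_m

assert extract(0) == (0, 0, 0)     # lane 0 of warp (m=0, n=0)
assert extract(35) == (3, 0, 1)    # the second warp advances along M first
assert extract(70) == (6, 1, 0)    # warps 2..3 belong to the next N group
```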
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
assert not a_transposed, "A must be not transposed"
mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
for j in T.vectorized(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_cols):
# Assign B_shared_elem
wi, wk = (
warp_n * warp_col_tiles + i * micro_size_y,
rk * chunk + ki * micro_size_k,
)
# load the B micro tile from the shared buffer into the local buffer
# must be transposed.
for j in T.vectorized(local_size_b):
if b_transposed:
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else:
mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)
def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
a_major = "col" if self.a_transposed else "row"
b_major = "col" if self.b_transposed else "row"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma_sm70(
mma_prefix,
a_major,
b_major,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
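The accumulator offsets issued by the loop above tile the per-thread register file densely; a small sketch with assumed warp_rows = warp_cols = 2 and local_size_out = 8 (the m16n16 case):

```python
warp_rows, warp_cols, local_size_out = 2, 2, 8
offsets = [i * warp_cols * local_size_out + j * local_size_out
           for i in range(warp_rows) for j in range(warp_cols)]
assert offsets == [0, 8, 16, 24]   # one disjoint 8-register accumulator slice per (i, j)
```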
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for loading an MMA operand (matrix A or B) into a fragment buffer.
This layout is used in conjunction with `inverse_mma_load_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
dtype = self.a_dtype if matrix_is_a else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
transform_func_rs_b: Callable = None
if dtype_bits == 16:
transform_func_sr_a = shared_16x4_to_mma_a_32x4_layout
transform_func_sr_b = shared_16x4_to_mma_b_32x4_layout_trans
transform_func_rs_b = shared_4x16_to_mma_b_32x4_layout
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(
i, j)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward(i: int, j: int, rep: int) -> tuple[int, int]:
"""
Given the row index `i`, the column index `j`, and the replication index `rep`
in the fragment, map them to a (lane_id, local_id) pair via `inverse_mma_load_layout`.
"""
lane_id, local_id = inverse_mma_load_layout.map_indices([i, j, rep])
return lane_id, local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_fn=forward,
replicate=2)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
from __future__ import annotations
from enum import IntEnum
import tilelang.language as T
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, BufferLoad, BufferRegion
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tilelang.utils import is_tensor_memory
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
make_half_bank_swizzled_layout,
make_quarter_bank_swizzled_layout,
make_linear_layout,
)
from tvm.runtime import convert
lift = convert
class SwizzleMode(IntEnum):
# SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
NONE = 0
SWIZZLE_128B = 2
SWIZZLE_64B = 4
SWIZZLE_32B = 6
def is_none(self) -> bool:
return self == SwizzleMode.NONE
def is_swizzle_32b(self) -> bool:
return self == SwizzleMode.SWIZZLE_32B
def is_swizzle_64b(self) -> bool:
return self == SwizzleMode.SWIZZLE_64B
def is_swizzle_128b(self) -> bool:
return self == SwizzleMode.SWIZZLE_128B
def swizzle_byte_size(self) -> int:
if self.is_swizzle_32b():
return 32
elif self.is_swizzle_64b():
return 64
elif self.is_swizzle_128b():
return 128
else:
return 1
def swizzle_atom_size(self) -> int:
if self.is_swizzle_32b():
return 32 // 16
elif self.is_swizzle_64b():
return 64 // 16
elif self.is_swizzle_128b():
return 128 // 16
else:
return 1
# derive from MMAIntrinEmitter as some layouts are the same
class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
To eliminate Python syntax within TIR Macro.
"""
# should be rewritten to support dynamic k_dim
tcgen05_prefix: str
a_shared_layout: Layout = None
b_shared_layout: Layout = None
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
num_elems_per_byte, is_m_first, thread_var)
def _assign_a_shared_layout(self, layout: Layout):
self.a_shared_layout = layout
return self
def _assign_b_shared_layout(self, layout: Layout):
self.b_shared_layout = layout
return self
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
# For tcgen05, warp_row_tiles can be as small as 8 because the .ws variants support m32
assert warp_row_tiles >= 8, f"warp_row_tiles must be at least 8, got {warp_row_tiles}"
assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
# four warps per block
self.warp_rows = warp_row_tiles // 8
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
# same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper
if layout is None or layout.is_equal(make_linear_layout(buffer)):
return SwizzleMode.NONE
elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_32B
elif layout.is_equal(make_half_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_64B
elif layout.is_equal(make_full_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_128B
else:
raise ValueError(f"Unsupported swizzle mode: {layout}")
def tcgen05mma(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
mbar,
clear_accum: PrimExpr = False):
if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
accum_dtype = self.accum_dtype
m_dim = self.block_row_warps * self.warp_row_tiles
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
scale_in_a = 1
scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
a_is_k_major = not self.a_transposed
b_is_k_major = self.b_transposed
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_dtype_in_bits = DataType(accum_dtype).bits
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 3:
raise ValueError(
f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, atom_k = (int(x) for x in meta)
enable_ws = atom_m != 128
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if a_is_k_major:
a_leading_byte_offset = 16
a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
else:
# MN Major
# LBO represents the distance between two atoms along the M dimension
# SBO represents the distance between two atoms along the K dimension
a_m_axis_atoms = m_dim // a_swizzle_atom_elems
if a_m_axis_atoms <= 1:
a_leading_byte_offset = 0
else:
a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size()
if a_m_axis_atoms <= 1:
a_stride_byte_offset = 8 * elems_in_bytes * m_dim
else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else:
# MN Major, K * N
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
if b_n_axis_atoms <= 1:
b_leading_byte_offset = 0
else:
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
if b_n_axis_atoms <= 1:
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
else:
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
# for example, if [n, k] where k is 128, we should split it into 2 atoms
# where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
instr_desc = self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
# Allocate an instruction descriptor wrapper and initialize it
a_dtype_abbrv = self.a_dtype_abbrv
mask_zero = T.Cast("int32", 0)
mask0 = mask1 = mask2 = mask3 = mask_zero
num_inst_m = 4 * self.warp_row_tiles // atom_m
num_inst_n = self.warp_col_tiles // atom_n
# Helper to allow BufferRegion/BufferLoad as inputs
def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf, mbar):
# Allocate SMEM descriptors for A and B
desc_a = T.alloc_tcgen05_smem_desc()
desc_b = T.alloc_tcgen05_smem_desc()
A_ptr = access_ptr_from(A_buf, "r")
B_ptr = access_ptr_from(B_buf, "r")
T.initialize_tcgen05_descriptor(
desc_a,
A_ptr,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4),
0,
False,
int(a_swizzle_mode),
)
T.initialize_tcgen05_descriptor(
desc_b,
B_ptr,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4),
0,
False,
int(b_swizzle_mode),
)
tmem_col_step = atom_n // (128 // atom_m)
for j in T.unroll(num_inst_n):
for i in T.unroll(num_inst_m):
for ki in T.unroll(0, (k_dim // micro_size_k)):
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
A_elem_offset = (
ki % ak_atom_size
) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
A_byte_offset = A_elem_offset * elems_in_bytes
B_byte_offset = B_elem_offset * elems_in_bytes
C_offset = (i * n_dim + j * tmem_col_step
) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
T.ptx_tcgen05_mma_ss(
a_dtype_abbrv,
desc_a.data,
A_byte_offset,
desc_b.data,
B_byte_offset,
C_local_buf.data,
C_offset,
instr_desc,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws,
)
T.tcgen05_mma_arrive(mbar)
return _warp_mma(A_buf, B_buf, C_local_buf, mbar)
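To illustrate the descriptor arithmetic above, a worked example (hypothetical fp16 A operand, k_dim = 64, K-major, 128-byte swizzle; values follow the K-major branches above, not a new computation):

```python
elems_in_bytes = 2                               # fp16
# non-swizzled K-major defaults:
lbo_linear = 8 * 8 * elems_in_bytes              # 128 bytes between 8-column groups
sbo_linear = 8 * 64 * elems_in_bytes             # 1024 bytes between 8-row groups (k_dim = 64)
# 128B-swizzled K-major overrides:
lbo_swizzled, sbo_swizzled = 16, 8 * 128
# initialize_tcgen05_descriptor stores both offsets shifted right by 4:
assert (lbo_swizzled >> 4, sbo_swizzled >> 4) == (1, 64)
```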
def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
raise NotImplementedError
def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout:
"""
Create the TCGEN5 tensor-memory layout used to store MMA accumulators.
Parameters
----------
tmem_buf : tir.Buffer
The buffer holding the tensor-memory (shared.tmem) accumulator of an MMA.
Returns
-------
Layout
Layout object describing how logical (i, j) coordinates map to the
swizzled tensor-memory offsets required by TCGEN5MMA.
Raises
------
AssertionError
If `tmem_buf` is not detected to be a tensor-memory buffer.
"""
assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
if len(tmem_buf.shape) != 2:
raise ValueError(
f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
m = int(tmem_buf.shape[0])
n = int(tmem_buf.shape[1])
k = int(self.chunk)
meta = self.get_tcgen5_mma_meta(m, n, k)
if len(meta) != 3:
raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, _ = (int(x) for x in meta)
if m % atom_m != 0 or n % atom_n != 0:
raise ValueError(
f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
)
def forward(i: PrimExpr, j: PrimExpr):
atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
ai = i % atom_m
aj = j % atom_n
if atom_m == 128:
# Layout D
return [
ai,
aj + atom_idx * atom_n,
]
if atom_m == 64:
# Layout E (.ws variant)
half_atom_n = atom_n // 2
return [
(ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64,
(aj % half_atom_n) + atom_idx * half_atom_n,
]
if atom_m == 32:
# Layout G
quarter_atom_n = atom_n // 4
return [
ai % 32 + (aj // quarter_atom_n) * 32,
(aj % quarter_atom_n) + atom_idx * quarter_atom_n,
]
raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}")
return Layout([m, n], forward)
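A minimal sketch (assumed atom_m = 128, atom_n = 64 and a hypothetical 256x128 accumulator) of the Layout-D branch above, showing how M atoms are appended along the tensor-memory column axis:

```python
atom_m, atom_n, m = 128, 64, 256

def forward_layout_d(i, j):
    # Same arithmetic as the atom_m == 128 case of forward() above.
    atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
    return [i % atom_m, j % atom_n + atom_idx * atom_n]

assert forward_layout_d(0, 0) == [0, 0]
assert forward_layout_d(130, 10) == [2, 74]    # second M atom -> columns 64..127
assert forward_layout_d(5, 70) == [5, 134]     # second N atom comes after both M atoms
```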
def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
return _ffi_api.get_tcgen5_mma_meta(
int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
desc = _ffi_api.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
DataType(self.a_dtype),
DataType(self.accum_dtype),
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
return lift(desc)
...@@ -8,6 +8,7 @@ from .mma_layout import ( ...@@ -8,6 +8,7 @@ from .mma_layout import (
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b, ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout, mma_store_32x8_to_shared_16x16_layout,
mma_store_32x2_to_shared_8x8_layout_fp64,
) )
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
...@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id): ...@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id):
return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id) return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)
def mma_store_index_map_fp64(thread_id, local_id):
return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id)
def mfma_store_index_map(thread_id, local_id): def mfma_store_index_map(thread_id, local_id):
return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
......
...@@ -4,8 +4,9 @@ from enum import IntEnum ...@@ -4,8 +4,9 @@ from enum import IntEnum
from typing import Callable from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferRegion
from tilelang.utils import is_fragment from tilelang.utils import is_fragment, retrive_ptr_from_buffer_region, is_full_region
from math import gcd
from tilelang.layout import ( from tilelang.layout import (
Layout, Layout,
make_full_bank_swizzled_layout, make_full_bank_swizzled_layout,
...@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# should be rewritten to support dynamic k_dim # should be rewritten to support dynamic k_dim
wgmma_prefix: str wgmma_prefix: str
# wgmma instruction M dimension
wgmma_inst_m: int
# wgmma instruction N dimension
wgmma_inst_n: int
a_shared_layout: Layout = None a_shared_layout: Layout = None
b_shared_layout: Layout = None b_shared_layout: Layout = None
...@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return self return self
def _initialize_wgmma_prefix(self, n_dim: int = 16): def _initialize_wgmma_prefix(self, n_dim: int = 16):
inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
assert inst_n % 8 == 0, (
f"inst_n must be a multiple of 8, got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert 8 <= inst_n <= 256, (
f"inst_n must be within [8, 256], got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# 256 bits per instruction # 256 bits per instruction
inst_k = 256 // DataType(self.a_dtype).bits inst_k = 256 // DataType(self.a_dtype).bits
self.wgmma_inst_m = inst_m
self.wgmma_inst_n = inst_n
self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}" self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
...@@ -146,13 +161,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -146,13 +161,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
raise ValueError(f"Unsupported swizzle mode: {layout}") raise ValueError(f"Unsupported swizzle mode: {layout}")
def wgmma(self, def wgmma(self,
A_buf: Buffer, A_region: BufferRegion,
B_buf: Buffer, B_region: BufferRegion,
C_local_buf: Buffer, C_region: BufferRegion,
clear_accum: PrimExpr = False): clear_accum: PrimExpr = False,
wg_wait: int = 0):
if is_fragment(A_buf): if is_fragment(A_region):
return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum) return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
local_size_out = self.local_size_out local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv a_dtype_abbrv = self.a_dtype_abbrv
...@@ -164,7 +180,6 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -164,7 +180,6 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1 scale_in_a = 1
scale_in_b = 1 scale_in_b = 1
...@@ -173,8 +188,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -173,8 +188,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
a_is_k_major = not self.a_transposed a_is_k_major = not self.a_transposed
b_is_k_major = self.b_transposed b_is_k_major = self.b_transposed
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
elems_in_bits = DataType(self.a_dtype).bits elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8 elems_in_bytes = elems_in_bits // 8
...@@ -182,6 +197,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -182,6 +197,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
# by default, we utilize non-swizzle layout offset # by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
...@@ -240,41 +257,69 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -240,41 +257,69 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# where max specially handles the case when n_dim is 8. # where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
thread_binding = self.get_thread_binding()
A_ptr = retrive_ptr_from_buffer_region(A_region)
B_ptr = retrive_ptr_from_buffer_region(B_region)
assert is_full_region(C_region), "Fragment output C must be a full region"
C_buf = C_region.buffer
-@T.macro
-def _warp_mma(A_buf, B_buf, C_local_buf):
-    # TODO(lei): inject warpgroup_fence_operand for C_local_buf
-    desc_a = T.alloc_descriptor()
-    desc_b = T.alloc_descriptor()
-    T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
-                            int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-    T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
-                            int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
-    T.warpgroup_arrive()
-    for ki in T.serial(0, (k_dim // micro_size_k)):
-        for i in T.serial(m_dim // 64):
-            A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
-                ki // ak_atom_size
-            ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
-            B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
-                ki % bk_atom_size
-            ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
-            C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
-            T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
-                           a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
-                           (A_offset * elems_in_bytes) >> 4, desc_b.data,
-                           (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
-                           scale_out, scale_in_a, scale_in_b)
-    T.warpgroup_commit_batch()
-    T.warpgroup_wait(0)
-
-return _warp_mma(A_buf, B_buf, C_local_buf)
+@T.macro
+def _warp_mma(A_ptr, B_ptr, C_buf):
+    tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+
+    desc_a = T.alloc_wgmma_desc()
+    desc_b = T.alloc_wgmma_desc()
+    T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
+                                  int(a_leading_byte_offset >> 4),
+                                  int(a_stride_byte_offset >> 4))
+    T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
+                                  int(b_leading_byte_offset >> 4),
+                                  int(b_stride_byte_offset >> 4))
+    T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+    T.warpgroup_arrive()
+
+    for j in T.unroll(num_inst_n):
+        for i in T.unroll(num_inst_m):
+            for ki in T.unroll(k_dim // micro_size_k):
+                scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
+                warp_i = (warp_m // 4) * num_inst_m + i
+                warp_j = warp_n * num_inst_n + j
+                A_offset = (
+                    ki % ak_atom_size
+                ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
+                    ki // ak_atom_size
+                ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
+                    ki % bk_atom_size
+                ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (
+                    ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
+                    (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
+                T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
+                               a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
+                               (A_offset * elems_in_bytes) >> 4, desc_b.data,
+                               (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
+                               scale_out, scale_in_a, scale_in_b)
+    T.warpgroup_commit_batch()
+    if wg_wait >= 0:
+        T.warpgroup_wait(wg_wait)
+    T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+
+return _warp_mma(A_ptr, B_ptr, C_buf)
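Both descriptor initializations shift their byte offsets right by 4 because the wgmma shared-memory matrix descriptor stores addresses and offsets in 16-byte units. The sketch below shows one plausible packing of such a descriptor; `make_smem_descriptor` is a made-up helper, and the field positions follow my reading of the PTX ISA matrix-descriptor format rather than the emitter's actual encoder.

```python
# Hedged sketch of a wgmma shared-memory matrix descriptor. The bit layout is an
# assumption based on the PTX ISA documentation; tilelang's real encoding may differ.
def make_smem_descriptor(smem_addr: int, lbo_bytes: int, sbo_bytes: int,
                         swizzle_mode: int) -> int:
    desc = 0
    desc |= (smem_addr >> 4) & 0x3FFF           # start address, 16-byte granularity
    desc |= ((lbo_bytes >> 4) & 0x3FFF) << 16   # leading-dimension byte offset (LBO)
    desc |= ((sbo_bytes >> 4) & 0x3FFF) << 32   # stride-dimension byte offset (SBO)
    desc |= (swizzle_mode & 0x3) << 62          # swizzle selector (none/32B/64B/128B)
    return desc

# Illustrative values only: a k-major operand with LBO = 16 bytes, SBO = 1024 bytes.
print(hex(make_smem_descriptor(0x4000, 16, 1024, swizzle_mode=1)))
```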
 def wgmma_rs(self,
-             A_buf: Buffer,
-             B_buf: Buffer,
-             C_local_buf: Buffer,
-             clear_accum: PrimExpr = False):
+             A_region: BufferRegion,
+             B_region: BufferRegion,
+             C_region: BufferRegion,
+             clear_accum: PrimExpr = False,
+             wg_wait: int = 0):
local_size_a = self.local_size_a local_size_a = self.local_size_a
local_size_out = self.local_size_out local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv a_dtype_abbrv = self.a_dtype_abbrv
...@@ -286,75 +331,111 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -286,75 +331,111 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1 scale_in_a = 1
scale_in_b = 1 scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
elems_in_bytes = DataType(self.a_dtype).bits // 8 elems_in_bytes = DataType(self.a_dtype).bits // 8
a_bits = DataType(self.a_dtype).bits
accum_bits = DataType(accum_dtype).bits
a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
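The two register counts are plain ceiling divisions: the operand bits held per thread, rounded up to whole 32-bit registers, which is the granularity `T.warpgroup_fence_operand` works with. A minimal restatement with made-up operand counts (the `regs_for` helper is illustrative, not tilelang API):

```python
import math

# (elements * bits_per_element) rounded up to whole 32-bit registers.
def regs_for(num_elems: int, elem_bits: int) -> int:
    return (num_elems * elem_bits + 31) // 32

assert regs_for(32, 16) == math.ceil(32 * 16 / 32) == 16  # 32 fp16 values -> 16 regs
assert regs_for(8, 32) == 8                               # 8 fp32 accumulators -> 8 regs
assert regs_for(1, 16) == 1                               # a partial register still counts
```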
b_is_k_major = self.b_transposed b_is_k_major = self.b_transposed
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
 b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none(): if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major: if b_is_k_major:
b_leading_byte_offset = 16 b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else: else:
# MN Major # MN Major
# LBO represents the distance between two atoms along the N dimension # LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension # SBO represents the distance between two atoms along the K dimension
-b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+b_n_axis_atoms = n_dim // b_swizzle_atom_elems
 if b_n_axis_atoms <= 1:
     b_leading_byte_offset = 0
 else:
-    b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (
-        b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
 if b_n_axis_atoms <= 1:
     b_stride_byte_offset = 8 * elems_in_bytes * n_dim
 else:
-    b_stride_byte_offset = 8 * elems_in_bytes * (
-        b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+    b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
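In the MN-major path, LBO is the byte distance between neighbouring swizzle atoms along N and SBO the distance along K, as the comments above state. The standalone helper below plugs concrete numbers into the added branch; the helper name, the 64x64 tile shape, and the fp16 element size are illustrative assumptions, not part of the emitter.

```python
# Illustrative restatement of the MN-major LBO/SBO selection for operand B.
def mn_major_b_offsets(n_dim: int, k_dim: int, elems_in_bytes: int,
                       swizzle_atom_elems: int):
    n_axis_atoms = n_dim // swizzle_atom_elems
    lbo = 0 if n_axis_atoms <= 1 else 8 * 8 * elems_in_bytes * k_dim
    sbo = 8 * elems_in_bytes * (n_dim if n_axis_atoms <= 1 else swizzle_atom_elems)
    return lbo, sbo

# 64x64 fp16 tile with 64-element (128-byte) swizzle atoms: a single atom spans
# the whole N dimension, so LBO collapses to 0 and SBO is 1024 bytes.
print(mn_major_b_offsets(n_dim=64, k_dim=64, elems_in_bytes=2, swizzle_atom_elems=64))
```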
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
thread_binding = self.get_thread_binding()
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(C_region), "Fragment output C must be a full region"
A_buf = A_region.buffer
B_ptr = retrive_ptr_from_buffer_region(B_region)
C_buf = C_region.buffer
-@T.macro
-def _warp_mma(A_buf, B_buf, C_local_buf):
-    desc_b = T.alloc_descriptor()
-    T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
-                            int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
-    for ki in T.serial(0, (k_dim // micro_size_k)):
-        for i in T.serial(m_dim // 64):
-            k_dim_offset = ki * micro_size_k
-            A_offset = ki * warp_rows * local_size_a + i * local_size_a
-            B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1]
-            C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
-            T.ptx_wgmma_rs(
-                accum_dtype,
-                wgmma_prefix,
-                self.a_transposed,
-                not self.b_transposed,
-                a_dtype_abbrv,
-                b_dtype_abbrv,
-                accum_dtype_abbrv,
-                A_buf.data,
-                A_offset,
-                desc_b.data,
-                (B_offset * elems_in_bytes) >> 4,
-                C_local_buf.data,
-                C_offset,
-                scale_out,
-                scale_in_a,
-                scale_in_b,
-            )
-
-return _warp_mma(A_buf, B_buf, C_local_buf)
+@T.macro
+def _warp_mma(A_buf, B_ptr, C_buf):
+    tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+
+    desc_b = T.alloc_wgmma_desc()
+    T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
+                                  int(b_leading_byte_offset >> 4),
+                                  int(b_stride_byte_offset >> 4))
+    T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
+    T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+    T.warpgroup_arrive()
+
+    for j in T.unroll(0, num_inst_n):
+        for i in T.unroll(num_inst_m):
+            for ki in T.unroll(0, (k_dim // micro_size_k)):
+                warp_j = warp_n * num_inst_n + j
+                scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
+
+                A_offset = ki * warp_rows * local_size_a + i * local_size_a
+                B_offset = (
+                    ki // bk_atom_size
+                ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
+                    ki % bk_atom_size) * micro_size_k if b_is_k_major else (
+                        ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
+                        (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
+                C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
+                T.ptx_wgmma_rs(
+                    accum_dtype,
+                    wgmma_prefix,
+                    self.b_transposed,
+                    a_dtype_abbrv,
+                    b_dtype_abbrv,
+                    accum_dtype_abbrv,
+                    A_buf.data,
+                    A_offset,
+                    desc_b.data,
+                    (B_offset * elems_in_bytes) >> 4,
+                    C_buf.data,
+                    C_offset,
+                    scale_out,
+                    scale_in_a,
+                    scale_in_b,
+                )
+    T.warpgroup_commit_batch()
+    if wg_wait >= 0:
+        T.warpgroup_wait(wg_wait)
+    T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+    T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
+
+return _warp_mma(A_buf, B_ptr, C_buf)
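In both macros the accumulator scaling comes from the nested `T.Select`: the accumulator is zeroed only on the first k-step of a tile, and only when `clear_accum` is requested; every later k-step accumulates. A plain-Python restatement for clarity (the `scale_out` function below is just an illustration of the predicate, not tilelang code):

```python
# Equivalent of T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)).
def scale_out(ki: int, clear_accum: bool) -> int:
    if ki != 0:
        return 1                     # inner k-steps always accumulate into C
    return 0 if clear_accum else 1   # first step may start from a cleared C

assert scale_out(0, True) == 0    # fresh accumulator on the first step
assert scale_out(0, False) == 1   # keep the existing contents of C
assert scale_out(3, True) == 1    # later steps are unaffected by clear_accum
```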
def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
""" """