Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -4,17 +4,24 @@ This module provides functionality for auto-tuning tilelang programs, including
and performance optimization through configuration search.
"""
from __future__ import annotations
from dataclasses import dataclass
import tilelang
from tilelang import tvm as tvm
from tilelang.jit import JITImpl
from tilelang.jit.kernel import JITKernel
from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Literal, Any, overload)
from tqdm import tqdm
from typing import (Callable, Generic, Literal, Any, TypeVar)
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging
import functools
import concurrent.futures
import torch
import os
......@@ -30,7 +37,6 @@ from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target
from tilelang.jit.param import _P, _RProg
from tilelang import __version__
......@@ -524,12 +530,12 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException:
logger.info(
logger.warning(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
)
continue
except Exception:
logger.info(
logger.warning(
f"An error occurred while testing config {config}, checkout autotuner.log for more details"
)
logger.debug(f"Error: {traceback.format_exc()}")
......@@ -585,9 +591,13 @@ class AutoTuner:
return self.run()
class _AutoTunerImplementation:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
_P = ParamSpec('_P')
_T = TypeVar('_T')
@dataclass
class AutoTuneImpl(Generic[_P, _T]):
jit_impl: JITImpl
warmup: int = 25
rep: int = 100
......@@ -603,125 +613,51 @@ class _AutoTunerImplementation:
manual_check_prog: Callable = None
cache_input_tensors: bool = False
def __init__(self,
configs: dict | Callable,
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False) -> None:
"""Initialize the AutoTunerImplementation.
def __post_init__(self):
self._tuner_cache = {}
def get_tuner(self):
autotuner = AutoTuner(
self.jit_impl.func, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=self.jit_impl.out_idx,
execution_backend=self.jit_impl.execution_backend,
target=self.jit_impl.target,
target_host=self.jit_impl.target_host,
verbose=self.jit_impl.verbose,
pass_configs=self.jit_impl.pass_configs,
)
autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
return autotuner
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation
supply_prog: Custom function to provide input tensors
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation
max_mismatched_ratio: Allowed percentage of mismatched values
skip_check: Bypass validation against reference implementation
manual_check_prog: Custom validation function
cache_input_tensors: Reuse input tensors across trials
"""
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
return wrapper
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel:
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
autotuner = self.get_tuner()
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
def autotune( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None,
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
configs: dict | Callable,
# profile arguments
......@@ -795,22 +731,26 @@ def autotune( # This is the new public interface
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _AutoTunerImplementation(
configs=configs,
warmup=warmup,
rep=rep,
timeout=timeout,
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return configured_decorator
def decorator(impl):
assert isinstance(
impl, JITImpl
), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
return AutoTuneImpl(
jit_impl=impl,
configs=configs,
warmup=warmup,
rep=rep,
timeout=timeout,
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return decorator
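# Usage sketch (illustrative only): the kernel below is hypothetical and its body is
# elided; `configs` values are forwarded as keyword arguments to the @tilelang.jit
# kernel factory, the first call with a given argument tuple runs the search, and
# later calls reuse the cached tuned kernel.
def _autotune_usage_example():  # pragma: no cover - documentation sketch
    @autotune(configs={"block_M": [64, 128], "block_N": [64, 128]})
    @tilelang.jit(out_idx=[-1])
    def matmul(M: int, N: int, K: int, block_M: int = 128, block_N: int = 128):
        ...  # kernel body elided

    return matmul(1024, 1024, 1024)  # tunes once for this (M, N, K), then hits the cache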
......@@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import math
from queue import PriorityQueue
from typing import Iterable
from collections.abc import Iterable
import numpy as np
import tvm
......
from __future__ import annotations
from typing import Mapping
from collections.abc import Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
......
......@@ -64,7 +64,7 @@ def get_cc():
return None
@functools.lru_cache(maxsize=None)
@functools.cache
def get_cplus_compiler():
"""Return the path to the default C/C++ compiler.
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm.runtime import ndarray
from tvm import runtime
def convert_func(tvm_func, tensor_type, to_dlpack_func):
......@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz
}:
return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack_func(arg))
return runtime.from_dlpack(to_dlpack_func(arg))
return arg
def _wrapper(*args):
......
......@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs
import subprocess
import tvm.ffi
import tvm_ffi
from tvm.contrib import utils
from tvm.base import py_str
......@@ -97,7 +97,7 @@ def compile_hip(code,
return data
@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True)
@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
"""use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco")
......
......@@ -7,9 +7,12 @@ from __future__ import annotations
import os
import subprocess
import warnings
from tilelang.env import CUDA_HOME
import tvm.ffi
import contextlib
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
import shutil
import tempfile
import tvm_ffi
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.base import py_str
......@@ -125,6 +128,154 @@ def compile_cuda(code,
return data
def default_compile_options(compile_flags: list[str] | None = None) -> list[str]:
"""
Build a set of default NVCC compile options for TileLang generated sources.
Includes C++ standard and common include paths (TileLang templates, CUTLASS,
CUDA include). Merges user-provided compile flags if given.
Parameters
----------
compile_flags : Optional[List[str]]
Additional flags to include. Each item is split shell-style (quotes preserved) before being appended to the defaults.
Returns
-------
List[str]
A list of flags suitable for NVCC's command line.
"""
options: list[str] = ["-std=c++17"]
try:
if TILELANG_TEMPLATE_PATH:
options.append(f"-I{TILELANG_TEMPLATE_PATH}")
except Exception:
pass
try:
if CUTLASS_INCLUDE_DIR:
options.append(f"-I{CUTLASS_INCLUDE_DIR}")
except Exception:
pass
try:
if CUDA_HOME:
options.append(f"-I{os.path.join(CUDA_HOME, 'include')}")
except Exception:
pass
# Preserve user flags exactly, including repeated tokens required by NVCC
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if compile_flags:
import shlex
for flag in compile_flags:
# Split each string like a shell would, preserving quoted args
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
options.extend(tokens)
return options
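# Example (sketch): repeated NVCC flag pairs survive the shlex-based merge intact;
# the optimization level and architecture values below are placeholders.
def _example_default_compile_options() -> list[str]:
    return default_compile_options([
        "-O3",
        "-gencode arch=compute_80,code=sm_80",
        "-gencode arch=compute_90,code=sm_90",
    ])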
def get_ptx_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to PTX using NVCC and return as text.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print NVCC output when True.
Returns
-------
str
PTX text.
"""
opts = default_compile_options(compile_flags)
ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose)
try:
return ptx_bytes.decode("utf-8")
except Exception:
return str(ptx_bytes)
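# Example (sketch): dumping PTX for a toy kernel. The kernel source and the "-arch"
# value are illustrative; any extra flags are merged with the defaults above.
def _example_ptx_dump() -> str:
    toy_kernel = r"""
    extern "C" __global__ void add_one(float* x) {
        x[threadIdx.x] += 1.0f;
    }
    """
    return get_ptx_from_source(toy_kernel, compile_flags=["-arch=sm_80"], verbose=False)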
def _find_tool(name: str) -> str | None:
"""Find a CUDA binary in PATH or under CUDA_HOME/bin."""
path = shutil.which(name)
if path:
return path
if CUDA_HOME:
candidate = os.path.join(CUDA_HOME, "bin", name)
if os.path.exists(candidate):
return candidate
return None
def get_sass_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to CUBIN and disassemble to SASS.
Uses nvdisasm if available; otherwise falls back to cuobjdump.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print tool outputs when True.
Returns
-------
str
SASS text.
"""
opts = default_compile_options(compile_flags)
cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose)
# Write to a temp .cubin file
with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp:
tmp.write(cubin_bytes)
cubin_path = tmp.name
# Try disassembly tools (prefer nvdisasm, fallback cuobjdump)
cand_nvdisasm = _find_tool("nvdisasm")
cand_cuobjdump = _find_tool("cuobjdump")
if not cand_nvdisasm and not cand_cuobjdump:
raise RuntimeError(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
last_err: str | None = None
try:
# Attempt nvdisasm first
tools_to_try = []
if cand_nvdisasm:
tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path]))
if cand_cuobjdump:
tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path]))
for tool_name, cmd in tools_to_try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate()
text = py_str(out)
if verbose:
print(f"[{tool_name}] output:\n{text}")
if proc.returncode == 0 and text.strip():
return text
last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
# If we reach here, all attempts failed
raise RuntimeError(f"SASS disassembly failed. Tried tools: "
f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
finally:
with contextlib.suppress(Exception):
os.remove(cubin_path)
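# Example (sketch): SASS inspection follows the same pattern; the helper compiles to a
# cubin and disassembles with nvdisasm when present, otherwise cuobjdump.
def _example_sass_dump(cuda_source: str) -> str:
    # `cuda_source` is any self-contained CUDA C++ kernel string.
    return get_sass_from_source(cuda_source, compile_flags=["-arch=sm_80"], verbose=True)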
def find_cuda_path():
"""Utility function to find cuda path
......@@ -182,14 +333,14 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file")
@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True)
@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True)
@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch):
"""Utility function to find libdevice
......@@ -254,7 +405,7 @@ def callback_libdevice_path(arch):
return ""
@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True)
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
......@@ -400,7 +551,7 @@ def have_cudagraph():
return False
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True)
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
......@@ -413,7 +564,7 @@ def have_bf16(compute_version):
return major >= 8
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True)
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not
......@@ -430,7 +581,7 @@ def have_fp8(compute_version):
return any(conditions)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target):
"""Whether TMA support is provided in the specified compute capability or not
......
......@@ -21,7 +21,7 @@ import subprocess
import os
from os.path import join, exists
import tvm.ffi
import tvm_ffi
from tvm.base import py_str
import tvm.runtime
import tvm.target
......@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
raise RuntimeError(msg)
@tvm.ffi.register_func("tvm_callback_rocm_link", override=True)
@tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True)
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
......@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
return cobj_bin
@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True)
def callback_rocm_bitcode_path(rocdl_dir=None):
"""Utility function to find ROCm device library bitcodes
......@@ -226,8 +226,11 @@ def have_matrixcore(compute_version=None):
return False
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/dtk"):
@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"):
# @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
# def get_rocm_arch(rocm_path="/opt/dtk"):
"""Utility function to get the AMD GPU architecture
Parameters
......
from __future__ import annotations
from typing import Callable
from tvm import register_func
import tvm_ffi
from tvm.target import Target
......@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool =
and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True.
"""
register_func("tilelang_callback_cuda_postproc", f=func, override=override)
tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override)
def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
......@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True.
"""
register_func("tilelang_callback_hip_postproc", f=func, override=override)
tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)
def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
......
......@@ -7,6 +7,7 @@ from typing import Callable
import tilelang.transform
from tilelang import tvm as tvm
from tvm import tir
import tvm_ffi
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
......@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return lambda func: not get_device_call(is_device_c)(func)
@tvm.register_func("tilelang_callback_cuda_compile", override=True)
@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ:
......@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target):
return ptx
@tvm.register_func("tilelang_callback_hip_compile", override=True)
@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..")
tl_template_path = osp.abspath(osp.join(project_root, "src"))
......@@ -182,7 +183,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
elif target.kind.name == "llvm":
device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target)
elif target.kind.name == "metal":
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
else:
......@@ -241,6 +242,6 @@ def lower(
host_mod = host_codegen(host_mod, target_host)
host_mod.import_module(codegen_mod)
return CompiledArtifact(
host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod)
host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
......@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Normalize negative indices to canonical non-negative form
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
# Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions
......@@ -118,8 +120,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# TODO(lei): return to tir pass when kSymbolicBound simplification
# is merged into tvm.
mod = tilelang.transform.Simplify()(mod)
# Try to vectorize loop with dynamic shape
mod = tilelang.transform.LoopVectorizeDynamic()(mod)
return mod
......
......@@ -236,6 +236,10 @@ class Environment:
"1") # print kernel name on compile
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")
# Auto-tuning settings
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
"0.9") # percent of CPUs used
......@@ -274,6 +278,14 @@ class Environment:
def is_print_on_compilation_enabled(self) -> bool:
return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")
def use_gemm_v1(self) -> bool:
"""Return True if GEMM v1 should be used based on env.
Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of
{"1", "true", "yes", "on"} (case-insensitive).
"""
return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")
# Instantiate as a global configuration object
env = Environment()
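# Example (sketch): opting into the legacy GEMM lowering. The switch is read from the
# process environment, so the shell form is the most reliable; the script name below
# is hypothetical.
#
#   TILELANG_USE_GEMM_V1=1 python my_script.py
#
# Programmatic form (assumes the variable is read at access time, not cached):
#   os.environ["TILELANG_USE_GEMM_V1"] = "on"
#   env.use_gemm_v1()  # -> True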
......@@ -297,12 +309,11 @@ def prepend_pythonpath(path):
if env.TVM_IMPORT_PYTHON_PATH is not None:
prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
else:
tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm")
tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
assert os.path.exists(tvm_path), tvm_path
if tvm_path not in sys.path:
tvm_python_binding = os.path.join(tvm_path, 'python')
prepend_pythonpath(tvm_python_binding)
env.TVM_IMPORT_PYTHON_PATH = tvm_python_binding
prepend_pythonpath(tvm_path)
env.TVM_IMPORT_PYTHON_PATH = tvm_path
if os.environ.get("TVM_LIBRARY_PATH") is None:
os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
......
......@@ -2,10 +2,32 @@ from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm.runtime import convert
from .utils import (
mfma_store_index_map,)
from typing import Literal, Callable
from tilelang.utils import is_fragment
from tilelang.utils.language import to_buffer_region
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
lift = convert
......@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter:
k_pack: int | None = None,
is_m_first: bool | None = False,
b_preshuffle: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
......@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter:
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
......@@ -115,6 +139,7 @@ class MatrixCoreIntrinEmitter:
}[out_dtype]
in_dtype_abbrv = {
"bfloat16": "bf16",
"float16": "f16",
"float32": "f32",
"int8": "i8",
......@@ -126,6 +151,9 @@ class MatrixCoreIntrinEmitter:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
elif in_dtype_abbrv == "i8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
elif in_dtype_abbrv == "bf16":
# HIP intrinsic uses ...x{K}bf16_1k without an underscore before bf16
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}bf16_1k"
else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
......@@ -147,24 +175,6 @@ class MatrixCoreIntrinEmitter:
self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed
......@@ -200,6 +210,22 @@ class MatrixCoreIntrinEmitter:
return index_map, reverse_index_map
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
......@@ -229,7 +255,7 @@ class MatrixCoreIntrinEmitter:
(WARP_SIZE * block_row_warps)) % block_col_warps,
return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
......@@ -238,10 +264,15 @@ class MatrixCoreIntrinEmitter:
local_size_a = self.local_size_a
k_pack = self.k_pack
is_transposed = self.a_transposed
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
......@@ -257,20 +288,20 @@ class MatrixCoreIntrinEmitter:
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k),
warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
......@@ -279,10 +310,15 @@ class MatrixCoreIntrinEmitter:
local_size_b = self.local_size_b
k_pack = self.k_pack
is_transposed = self.b_transposed
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
......@@ -300,8 +336,8 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
B_base1 + r + col]
else:
for j in T.serial(warp_cols):
......@@ -311,12 +347,16 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mfma(self, A_local_buf, B_local_buf, C_local_buf):
def mfma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -329,8 +369,13 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
T.tvm_mfma(
mfma_suffix,
......@@ -340,15 +385,15 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype,
compute_out_dtype,
B_local_buf.data,
((j * k_pack + kp) * local_size_b) // local_size_b,
(b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b,
A_local_buf.data,
((i * k_pack + kp) * local_size_a) // local_size_a,
(a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a,
C_local_buf.data,
(i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
dtype=compute_out_dtype,
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
return _warp_mfma(A_local_buf, B_local_buf, C_local_buf)
def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps
......@@ -356,8 +401,7 @@ class MatrixCoreIntrinEmitter:
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_out = self.local_size_out
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
......@@ -366,7 +410,7 @@ class MatrixCoreIntrinEmitter:
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
# STS
# MMA Store must be in simulated instead of TVM Intrins
# The MFMA store must be emulated rather than lowered to a TVM intrinsic,
# because the intrinsic assumes threadIdx.x always equals the warp size.
@T.macro
......@@ -400,6 +444,217 @@ class MatrixCoreIntrinEmitter:
thread_binding) if is_global else _warp_stmatrix_shared(
C_local_buf, C_buf, thread_binding)
def make_mfma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout describing how an MFMA operand fragment (matrix A or B) is loaded into a fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
matrix : Literal["A", "B"]
Which operand ("A" or "B") the layout is generated for.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
k_dim = self.k_dim * self.k_pack
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment, map them to a thread index (lane id) according to `inverse_mfma_load_layout`.
"""
lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment, map them to a local index in a single thread according to `inverse_mfma_load_layout`.
"""
_, local_id = inverse_mfma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mfma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
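# Usage sketch (assumes the emitter is used inside a kernel body and that this tilelang
# version exposes T.annotate_layout; buffer names are illustrative): the fragments
# returned above are meant to be bound to the operand/accumulator buffers so their
# register layout matches the MFMA loads and stores.
#
#   T.annotate_layout({
#       A_local: emitter.make_mfma_load_layout(A_local, matrix="A"),
#       B_local: emitter.make_mfma_load_layout(B_local, matrix="B"),
#       C_local: emitter.make_mfma_store_layout(C_local),
#   })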
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
......@@ -421,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
is_m_first: bool | None = False,
a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mfma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
super().__init__(
a_dtype=a_dtype,
b_dtype=b_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
num_elems_per_byte=num_elems_per_byte,
k_pack=k_pack,
is_m_first=is_m_first,
thread_var=thread_var,
)
self._initialize_preshuffle(a_preshuffle, b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle
......
......@@ -45,6 +45,12 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col
def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
row = thread_id // 4
col = (thread_id % 4) * 2 + local_id
return row, col
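# Worked example for the fp64 store map above: thread_id=5, local_id=1 gives
# row = 5 // 4 = 1 and col = (5 % 4) * 2 + 1 = 3, i.e. lane 5 writes its second
# accumulator element to (1, 3) of the 8x8 tile.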
# sr represents spatial + reduction layout
# the first axis is spatial while the second axis is reduction
# mma.sync matrix A layout, if wanna trans, please apply map_indices
......
......@@ -3,13 +3,14 @@ import tilelang.language as T
from typing import Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)
from tilelang.utils import is_fragment
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b,
......@@ -40,6 +41,7 @@ class TensorCoreIntrinEmitter:
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"float64": "fp64",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
......@@ -78,6 +80,11 @@ class TensorCoreIntrinEmitter:
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
# For FP64, MMA shape is m8n8k4; adjust instance dims early
if DataType(a_dtype).bits == 64:
# Override default M/N dims for fp64 MMA
self.M_DIM = 8
# n_dim will be set to 8 in _initialize_micro_size via k_dim==4
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
......@@ -105,12 +112,21 @@ class TensorCoreIntrinEmitter:
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 8:
if k_dim == 4:
# fp64
self.mma_prefix = "m8n8k4"
elif k_dim == 8:
# typically used for tfloat32
self.mma_prefix = "m16n8k8"
elif k_dim == 16:
......@@ -125,22 +141,31 @@ class TensorCoreIntrinEmitter:
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
# For fp64 (k_dim==4), micro tile is 8x8, otherwise keep 16x{8|16}
if k_dim == 4:
# fp64 path: m_dim must be 8, n_dim 8
assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}"
self.n_dim = 8
self.micro_size_y = 8
self.warp_rows = warp_row_tiles // m_dim
self.warp_cols = warp_col_tiles // 8
else:
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
self.micro_size_k = k_dim
......@@ -158,8 +183,12 @@ class TensorCoreIntrinEmitter:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
from .utils import mma_store_index_map, mma_store_index_map_fp64
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
else:
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
......@@ -199,9 +228,47 @@ class TensorCoreIntrinEmitter:
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.a_dtype).bits == 64:
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x # 8
micro_size_k = self.micro_size_k # 4
local_size_a = self.local_size_a # 1
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ld_a_fp64(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
wi = warp_m * warp_row_tiles + i * micro_size_x
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if a_transposed:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
......@@ -226,6 +293,13 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
A_stride_last = A_buf.shape[-1]
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
......@@ -234,14 +308,16 @@ class TensorCoreIntrinEmitter:
thread_binding,
rk=0,
):
stride = A_shared_buf.shape[-1]
stride = A_stride_last
tx, _, warp_m = self.extract_thread_binding(thread_binding)
trans = self.a_transposed
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
A_shared_buf_elem = A_buf[A_base0 + wk,
A_base1 + wi] if a_transposed else A_buf[A_base0 + wi,
A_base1 + wk]
if ldmatrix_available:
T.ptx_ldmatrix(
......@@ -257,15 +333,59 @@ class TensorCoreIntrinEmitter:
else:
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi]
if a_transposed:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk,
A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi,
A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.b_dtype).bits == 64:
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y # 8
micro_size_k = self.micro_size_k # 4
local_size_b = self.local_size_b # 1
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ld_b_fp64(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for j in T.serial(warp_cols):
wi = warp_n * warp_col_tiles + j * micro_size_y
wk = rk * chunk + ki * micro_size_k
mi = tx // micro_size_k
mk = tx % micro_size_k
if b_transposed:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
else:
B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
......@@ -275,6 +395,13 @@ class TensorCoreIntrinEmitter:
b_dtype = self.b_dtype
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
B_stride_last = B_buf.shape[-1]
replicate_b = (self.n_dim == 16)
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
......@@ -298,7 +425,7 @@ class TensorCoreIntrinEmitter:
thread_binding,
rk=0,
):
stride = B_shared_buf.shape[-1]
stride = B_stride_last
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
trans = not b_transposed
......@@ -310,8 +437,9 @@ class TensorCoreIntrinEmitter:
)
if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk,
wi]
B_shared_buf_elem = B_buf[B_base0 + wi,
B_base1 + wk] if b_transposed else B_buf[B_base0 + wk,
B_base1 + wi]
T.ptx_ldmatrix(
b_dtype,
......@@ -329,7 +457,12 @@ class TensorCoreIntrinEmitter:
# must be transposed.
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi]
if b_transposed:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
......@@ -617,8 +750,10 @@ class TensorCoreIntrinEmitter:
from tilelang.utils import is_fragment
shape = local_buf.shape
assert is_fragment(
local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
......
from __future__ import annotations
def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
local_id = col
return tid, local_id
def shared_4x16_to_mma_b_32x4_layout(row, col, rep):
thread_id = row + 8 * col // 4 + 4 * rep
local_id = col % 4
return thread_id, local_id
def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8)
local_id = col
return thread_id, local_id
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + (
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
return row, col
def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8
col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8
return row, col
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
col = local_id
return row, col
def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id):
row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2)
col = local_id
return row, col
def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id):
row = thread_id % 4
col = local_id + (4 * (thread_id // 8))
return row, col
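# Worked example for the non-transposed B load map above: thread_id=10, local_id=3
# gives row = 10 % 4 = 2 and col = 3 + 4 * (10 // 8) = 7, i.e. lane 10's fourth
# element comes from (2, 7) of the 4x16 shared tile.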
from __future__ import annotations
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout,
shared_16x4_to_mma_b_32x4_layout_trans,
mma_32x8_to_shared_16x16_layout_fp32,
mma_32x8_to_shared_16x16_layout_fp16,
mma_load_a_32x4_to_shared_16x4_layout,
mma_load_b_32x4_to_shared_16x4_layout_trans,
mma_load_b_32x4_to_shared_4x16_layout,
)
lift = convert
class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
M_DIM = 16
# use lowercase as n_dim can be dynamic
# the smallest instructions can be m16n8k16, so the n_dim can also be 8
n_dim = 16
WARP_SIZE = 32
HALF_WARP_SIZE = WARP_SIZE // 2
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim)
self._initialize_mma_prefix(self.k_dim)
self._initialize_is_m_first(is_m_first)
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 4
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
self.local_size_a = (m_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_b = (n_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_out = (m_dim * n_dim) // self.WARP_SIZE
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 4:
# typically used for float16
self.mma_prefix = "m16n16k4"
else:
raise ValueError(f"Unsupported k_dim: {k_dim}")
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 16, f"warp_col_tiles must be at least 16, got {warp_col_tiles}"
assert warp_col_tiles % 16 == 0, f"warp_col_tiles must be divisible by 16, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32
if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
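# Note (illustrative): the forward map above takes (thread_id, local_id) over a
# [32, 8] domain to a (row, col) in the 16x16 accumulator tile, choosing the fp32 or
# fp16 variant by accumulator dtype; inverse=True returns the reverse mapping from
# (row, col) back to (thread_id, local_id).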
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: if True, the flattened thread index is laid out as
[block_row_warps (m), block_col_warps (n), warp_size], i.e. the m warp index varies slowest;
otherwise it is laid out as [block_col_warps (n), block_row_warps (m), warp_size].
In both cases the function returns (lane_id, warp_n, warp_m).
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
assert not a_transposed, "A must not be transposed"
mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
for j in T.vectorized(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_cols):
# Assign B_shared_elem
wi, wk = (
warp_n * warp_col_tiles + i * micro_size_y,
rk * chunk + ki * micro_size_k,
)
# load the B tile from the shared buffer into the local buffer;
# the access pattern depends on whether B is transposed.
for j in T.vectorized(local_size_b):
if b_transposed:
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
else:
mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)
def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
a_major = "col" if self.a_transposed else "row"
b_major = "col" if self.b_transposed else "row"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma_sm70(
mma_prefix,
a_major,
b_major,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for loading an MMA operand (A or B) into a fragment buffer.
The layout maps fragment indices to threads and local indices according to the
shared-to-MMA transform selected for `matrix`.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
dtype = self.a_dtype if matrix_is_a else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
transform_func_rs_b: Callable = None
if dtype_bits == 16:
transform_func_sr_a = shared_16x4_to_mma_a_32x4_layout
transform_func_sr_b = shared_16x4_to_mma_b_32x4_layout_trans
transform_func_rs_b = shared_4x16_to_mma_b_32x4_layout
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(
i, j)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward(i: int, j: int, rep: int) -> int:
"""
Given the row index `i`, column index `j` and replication index `rep` in the
fragment, map them to a (lane_id, local_id) pair via `inverse_mma_load_layout`.
"""
lane_id, local_id = inverse_mma_load_layout.map_indices([i, j, rep])
return lane_id, local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_fn=forward,
replicate=2)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# block_i and block_j range over block_row_warps and block_col_warps respectively
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
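# Minimal usage sketch (illustrative, not part of the original file): construct the
# SM70 emitter with hypothetical tile sizes and read back the derived geometry.
# Every parameter value below is an assumption chosen to satisfy the asserts above.
def _example_build_sm70_emitter():
    emitter = TensorCoreIntrinEmitter(
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float32",
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,  # multiple of 16
        warp_col_tiles=32,  # multiple of 16
        chunk=16,
    )
    # Derived by the initializers above: k_dim = 4, mma_prefix = "m16n16k4",
    # warp_rows = warp_cols = 2, local_size_a = local_size_b = 4,
    # local_size_out = 8, threads = 32 * (2 * 2) * 1 = 128.
    return emitter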
from __future__ import annotations
from enum import IntEnum
import tilelang.language as T
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, BufferLoad, BufferRegion
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tilelang.utils import is_tensor_memory
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
make_half_bank_swizzled_layout,
make_quarter_bank_swizzled_layout,
make_linear_layout,
)
from tvm.runtime import convert
lift = convert
class SwizzleMode(IntEnum):
# swizzle-mode encoding used by the descriptors below: NONE = 0, SWIZZLE_128B = 2, SWIZZLE_64B = 4, SWIZZLE_32B = 6
NONE = 0
SWIZZLE_128B = 2
SWIZZLE_64B = 4
SWIZZLE_32B = 6
def is_none(self) -> bool:
return self == SwizzleMode.NONE
def is_swizzle_32b(self) -> bool:
return self == SwizzleMode.SWIZZLE_32B
def is_swizzle_64b(self) -> bool:
return self == SwizzleMode.SWIZZLE_64B
def is_swizzle_128b(self) -> bool:
return self == SwizzleMode.SWIZZLE_128B
def swizzle_byte_size(self) -> int:
if self.is_swizzle_32b():
return 32
elif self.is_swizzle_64b():
return 64
elif self.is_swizzle_128b():
return 128
else:
return 1
def swizzle_atom_size(self) -> int:
if self.is_swizzle_32b():
return 32 // 16
elif self.is_swizzle_64b():
return 64 // 16
elif self.is_swizzle_128b():
return 128 // 16
else:
return 1
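# Illustrative mapping of the enum above, derivable directly from its methods:
#   SwizzleMode.NONE         -> swizzle_byte_size() = 1,   swizzle_atom_size() = 1
#   SwizzleMode.SWIZZLE_32B  -> swizzle_byte_size() = 32,  swizzle_atom_size() = 2
#   SwizzleMode.SWIZZLE_64B  -> swizzle_byte_size() = 64,  swizzle_atom_size() = 4
#   SwizzleMode.SWIZZLE_128B -> swizzle_byte_size() = 128, swizzle_atom_size() = 8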
# derive from MMAIntrinEmitter as some layouts are the same
class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
To eliminate Python syntax within TIR Macro.
"""
# should be rewritten to support dynamic k_dim
tcgen05_prefix: str
a_shared_layout: Layout = None
b_shared_layout: Layout = None
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
num_elems_per_byte, is_m_first, thread_var)
def _assign_a_shared_layout(self, layout: Layout):
self.a_shared_layout = layout
return self
def _assign_b_shared_layout(self, layout: Layout):
self.b_shared_layout = layout
return self
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
# For tcgen05, warp_row_tiles can be as small as 8 because the .ws variant supports M = 32
assert warp_row_tiles >= 8, f"warp_row_tiles must be at least 8, got {warp_row_tiles}"
assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
# four warps per block
self.warp_rows = warp_row_tiles // 8
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
# same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper
if layout is None or layout.is_equal(make_linear_layout(buffer)):
return SwizzleMode.NONE
elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_32B
elif layout.is_equal(make_half_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_64B
elif layout.is_equal(make_full_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_128B
else:
raise ValueError(f"Unsupported swizzle mode: {layout}")
def tcgen05mma(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
mbar,
clear_accum: PrimExpr = False):
if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
accum_dtype = self.accum_dtype
m_dim = self.block_row_warps * self.warp_row_tiles
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
scale_in_a = 1
scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
a_is_k_major = not self.a_transposed
b_is_k_major = self.b_transposed
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_dtype_in_bits = DataType(accum_dtype).bits
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 3:
raise ValueError(
f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, atom_k = (int(x) for x in meta)
enable_ws = atom_m != 128
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if a_is_k_major:
a_leading_byte_offset = 16
a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
else:
# MN Major
# LBO represents the distance between two atoms along the M dimension
# SBO represents the distance between two atoms along the K dimension
a_m_axis_atoms = m_dim // a_swizzle_atom_elems
if a_m_axis_atoms <= 1:
a_leading_byte_offset = 0
else:
a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size()
if a_m_axis_atoms <= 1:
a_stride_byte_offset = 8 * elems_in_bytes * m_dim
else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else:
# MN Major, K * N
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
if b_n_axis_atoms <= 1:
b_leading_byte_offset = 0
else:
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
if b_n_axis_atoms <= 1:
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
else:
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
# for example, if [n, k] where k is 128, we should split it into 2 atoms
# where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
instr_desc = self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
# Allocate an instruction descriptor wrapper and initialize it
a_dtype_abbrv = self.a_dtype_abbrv
mask_zero = T.Cast("int32", 0)
mask0 = mask1 = mask2 = mask3 = mask_zero
num_inst_m = 4 * self.warp_row_tiles // atom_m
num_inst_n = self.warp_col_tiles // atom_n
# Helper to allow BufferRegion/BufferLoad as inputs
def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf, mbar):
# Allocate SMEM descriptors for A and B
desc_a = T.alloc_tcgen05_smem_desc()
desc_b = T.alloc_tcgen05_smem_desc()
A_ptr = access_ptr_from(A_buf, "r")
B_ptr = access_ptr_from(B_buf, "r")
T.initialize_tcgen05_descriptor(
desc_a,
A_ptr,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4),
0,
False,
int(a_swizzle_mode),
)
T.initialize_tcgen05_descriptor(
desc_b,
B_ptr,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4),
0,
False,
int(b_swizzle_mode),
)
tmem_col_step = atom_n // (128 // atom_m)
for j in T.unroll(num_inst_n):
for i in T.unroll(num_inst_m):
for ki in T.unroll(0, (k_dim // micro_size_k)):
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
A_elem_offset = (
ki % ak_atom_size
) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
A_byte_offset = A_elem_offset * elems_in_bytes
B_byte_offset = B_elem_offset * elems_in_bytes
C_offset = (i * n_dim + j * tmem_col_step
) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
T.ptx_tcgen05_mma_ss(
a_dtype_abbrv,
desc_a.data,
A_byte_offset,
desc_b.data,
B_byte_offset,
C_local_buf.data,
C_offset,
instr_desc,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws,
)
T.tcgen05_mma_arrive(mbar)
return _warp_mma(A_buf, B_buf, C_local_buf, mbar)
def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
raise NotImplementedError
def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout:
"""
Create the TCGEN5 tensor-memory layout used to store MMA accumulators.
Parameters
----------
tmem_buf : tir.Buffer
The buffer representing the tensor memory that holds the MMA output.
Returns
-------
Layout
Layout object describing how logical (i, j) coordinates map to the
swizzled tensor-memory offsets required by TCGEN5MMA.
Raises
------
AssertionError
If `tmem_buf` is not detected to be a tensor-memory buffer.
"""
assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
if len(tmem_buf.shape) != 2:
raise ValueError(
f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
m = int(tmem_buf.shape[0])
n = int(tmem_buf.shape[1])
k = int(self.chunk)
meta = self.get_tcgen5_mma_meta(m, n, k)
if len(meta) != 3:
raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, _ = (int(x) for x in meta)
if m % atom_m != 0 or n % atom_n != 0:
raise ValueError(
f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
)
def forward(i: PrimExpr, j: PrimExpr):
atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
ai = i % atom_m
aj = j % atom_n
if atom_m == 128:
# Layout D
return [
ai,
aj + atom_idx * atom_n,
]
if atom_m == 64:
# Layout E (.ws variant)
half_atom_n = atom_n // 2
return [
(ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64,
(aj % half_atom_n) + atom_idx * half_atom_n,
]
if atom_m == 32:
# Layout G
quarter_atom_n = atom_n // 4
return [
ai % 32 + (aj // quarter_atom_n) * 32,
(aj % quarter_atom_n) + atom_idx * quarter_atom_n,
]
raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}")
return Layout([m, n], forward)
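# Worked example (hypothetical atom sizes): with atom_m = 128, atom_n = 128 and a
# (m, n) = (256, 128) accumulator, atom_idx = (i // 128) + (j // 128) * 2, so the
# element (i, j) = (130, 5) maps to [130 % 128, 5 + 1 * 128] = [2, 133] under the
# atom_m == 128 ("Layout D") branch above.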
def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
return _ffi_api.get_tcgen5_mma_meta(
int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
desc = _ffi_api.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
DataType(self.a_dtype),
DataType(self.accum_dtype),
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
return lift(desc)
......@@ -8,6 +8,7 @@ from .mma_layout import (
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
mma_store_32x2_to_shared_8x8_layout_fp64,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
......@@ -82,6 +83,10 @@ def mma_store_index_map(thread_id, local_id):
return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)
def mma_store_index_map_fp64(thread_id, local_id):
return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id)
def mfma_store_index_map(thread_id, local_id):
return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
......
......@@ -4,8 +4,9 @@ from enum import IntEnum
from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap
from tilelang.utils import is_fragment
from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferRegion
from tilelang.utils import is_fragment, retrive_ptr_from_buffer_region, is_full_region
from math import gcd
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
......@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# should be rewritten to support dynamic k_dim
wgmma_prefix: str
# wgmma instruction M dimension
wgmma_inst_m: int
# wgmma instruction N dimension
wgmma_inst_n: int
a_shared_layout: Layout = None
b_shared_layout: Layout = None
......@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return self
def _initialize_wgmma_prefix(self, n_dim: int = 16):
inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
assert inst_n % 8 == 0, (
f"inst_n must be a multiple of 8, got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert 8 <= inst_n <= 256, (
f"inst_n must be within [8, 256], got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# 256 bits per instruction
inst_k = 256 // DataType(self.a_dtype).bits
self.wgmma_inst_m = inst_m
self.wgmma_inst_n = inst_n
self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
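# Worked example (hypothetical sizes): for a_dtype = "float16" (16 bits),
# inst_k = 256 // 16 = 16; with warp_col_tiles = 64, inst_n = gcd(64, 256) = 64,
# so wgmma_prefix becomes "m64n64k16".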
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
......@@ -146,13 +161,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
raise ValueError(f"Unsupported swizzle mode: {layout}")
def wgmma(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
clear_accum: PrimExpr = False):
A_region: BufferRegion,
B_region: BufferRegion,
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
if is_fragment(A_buf):
return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum)
if is_fragment(A_region):
return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
......@@ -164,7 +180,6 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1
scale_in_b = 1
......@@ -173,8 +188,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
a_is_k_major = not self.a_transposed
b_is_k_major = self.b_transposed
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8
......@@ -182,6 +197,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
......@@ -240,41 +257,69 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
thread_binding = self.get_thread_binding()
A_ptr = retrive_ptr_from_buffer_region(A_region)
B_ptr = retrive_ptr_from_buffer_region(B_region)
assert is_full_region(C_region), "Fragment output C must be a full region"
C_buf = C_region.buffer
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
# TODO(lei): inject warpgroup_fence_operand for C_local_buf
desc_a = T.alloc_descriptor()
desc_b = T.alloc_descriptor()
T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
def _warp_mma(A_ptr, B_ptr, C_buf):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
for i in T.serial(m_dim // 64):
A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
C_offset = i * warp_cols * local_size_out # 4 warps as an unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
for j in T.unroll(num_inst_n):
for i in T.unroll(num_inst_m):
for ki in T.unroll(k_dim // micro_size_k):
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
warp_i = (warp_m // 4) * num_inst_m + i
warp_j = warp_n * num_inst_n + j
A_offset = (
ki % ak_atom_size
) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
if wg_wait >= 0:
T.warpgroup_wait(wg_wait)
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
return _warp_mma(A_buf, B_buf, C_local_buf)
return _warp_mma(A_ptr, B_ptr, C_buf)
def wgmma_rs(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
clear_accum: PrimExpr = False):
A_region: BufferRegion,
B_region: BufferRegion,
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
local_size_a = self.local_size_a
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
......@@ -286,75 +331,111 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1
scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
elems_in_bytes = DataType(self.a_dtype).bits // 8
a_bits = DataType(self.a_dtype).bits
accum_bits = DataType(accum_dtype).bits
a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
b_is_k_major = self.b_transposed
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else:
# MN Major
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
if b_n_axis_atoms <= 1:
b_leading_byte_offset = 0
else:
b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (
b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
if b_n_axis_atoms <= 1:
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
else:
b_stride_byte_offset = 8 * elems_in_bytes * (
b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
thread_binding = self.get_thread_binding()
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(C_region), "Fragment output C must be a full region"
A_buf = A_region.buffer
B_ptr = retrive_ptr_from_buffer_region(B_region)
C_buf = C_region.buffer
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
desc_b = T.alloc_descriptor()
T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
for ki in T.serial(0, (k_dim // micro_size_k)):
for i in T.serial(m_dim // 64):
k_dim_offset = ki * micro_size_k
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1]
C_offset = i * warp_cols * local_size_out # 4 warps as an unit
T.ptx_wgmma_rs(
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_buf.data,
A_offset,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_local_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
return _warp_mma(A_buf, B_buf, C_local_buf)
def _warp_mma(A_buf, B_ptr, C_buf):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for j in T.unroll(0, num_inst_n):
for i in T.unroll(num_inst_m):
for ki in T.unroll(0, (k_dim // micro_size_k)):
warp_j = warp_n * num_inst_n + j
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = (
ki // bk_atom_size
) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
ki % bk_atom_size) * micro_size_k if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_rs(
accum_dtype,
wgmma_prefix,
self.b_transposed,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_buf.data,
A_offset,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
T.warpgroup_commit_batch()
if wg_wait >= 0:
T.warpgroup_wait(wg_wait)
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
return _warp_mma(A_buf, B_ptr, C_buf)
def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
"""
......