Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""FFI APIs for tilelang"""
import tvm._ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm._ffi._init_api("tl", __name__) # pylint: disable=protected-access
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The auto-tune module for tl programs."""
import tilelang as tl
from tilelang import tvm as tvm
import inspect
from functools import wraps
from typing import Any, Callable, List, Literal
from tqdm import tqdm
import logging
from dataclasses import dataclass
import concurrent.futures
logging.basicConfig(
filename='out.log',
filemode='w',
level=logging.INFO,
format='%(asctime)s %(levelname)s:%(message)s')
@dataclass(frozen=True)
class JITContext:
mod: tl.Profiler
out_idx: List[int]
supply_type: tl.TensorSupplyType
ref_prog: Callable
rtol: float
atol: float
skip_check: bool
profiler: Literal['torch', 'tvm']
target: Literal['cuda', 'hip']
class Autotuner:
def __init__(
self,
fn: Callable,
configs: Any,
keys: List[str],
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
):
self.fn = fn
self.configs = configs
self.keys = keys
self.warmup = warmup
self.rep = rep
self.timeout = timeout
# Precompute cached variables
self.ref_latency_cache = None
self.jit_input_tensors = None
self.ref_input_tensors = None
def run(self, *args: Any, **kwds: Any) -> Any:
sig = inspect.signature(self.fn)
bound_args = sig.bind(*args, **kwds)
bound_args.apply_defaults()
best_latency = 1e8
best_config = None
def target_fn(*new_args, **kwds):
jit_context = self.fn(*new_args, **kwds)
# Unpack the context
mod = jit_context.mod
profiler = jit_context.profiler
skip_check = jit_context.skip_check
ref_prog = jit_context.ref_prog
rtol = jit_context.rtol
atol = jit_context.atol
            if self.jit_input_tensors is None:
                self.jit_input_tensors = mod._get_inputs(with_output=(profiler == "tvm"))
if (not skip_check) and (ref_prog is not None):
mod.assert_allclose(ref_prog, rtol=rtol, atol=atol)
latency = mod.do_bench(
mod.func,
n_warmup=self.warmup,
n_repeat=self.rep,
profiler=profiler,
input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
                if self.ref_input_tensors is None:
                    self.ref_input_tensors = mod._get_inputs(with_output=False)
self.ref_latency_cache = mod.do_bench(
ref_prog,
n_warmup=self.warmup,
n_repeat=self.rep,
profiler="torch",
input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
progress_bar = tqdm(self.configs, desc="Running configurations")
for config in progress_bar:
new_args = []
for name, value in bound_args.arguments.items():
if name not in self.keys:
new_args.append(value)
else:
new_args.append(config[name])
new_args = tuple(new_args)
ref_latency = None
try:
# Use ThreadPoolExecutor to enforce timeout on target_fn execution
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(target_fn, *new_args, **kwds)
latency, ref_latency = future.result(timeout=self.timeout)
except concurrent.futures.TimeoutError:
logging.error(f"Timeout exceeded for config {config}. Skipping this configuration.")
continue
except Exception as e:
logging.error(f"An error occurred while testing config {config}: {e}")
continue
logging.info(f"Config {config} latency: {latency}")
progress_bar.set_postfix({"best_latency": best_latency})
if latency < best_latency:
best_latency = latency
best_config = config
tqdm.write(f"Tuned Latency {latency} with config {config}")
return best_latency, best_config, ref_latency
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.run(*args, **kwds)
def autotune(configs: Any,
keys: List[str],
warmup: int = 25,
rep: int = 100,
timeout: int = 100) -> Callable:
"""
Decorator for tl program
"""
def decorator(fn: Callable) -> Autotuner:
return Autotuner(fn, configs=configs, keys=keys, warmup=warmup, rep=rep, timeout=timeout)
return decorator
def jit(out_idx: List[int],
supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal,
ref_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
skip_check: bool = False,
profiler: Literal['auto', 'torch', 'tvm'] = 'auto',
target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
def wrapper(fn: Callable):
@wraps(fn)
def decorator(*args, **kwargs) -> float:
            # Enable merging of static shared memory allocations
            # (tir.merge_static_smem) for more efficient fusion.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
mod, params = tl.lower(fn(*args, **kwargs), target=target)
mod = tl.Profiler(mod, params, out_idx, supply_type)
return JITContext(
mod=mod,
out_idx=out_idx,
supply_type=supply_type,
ref_prog=ref_prog,
rtol=rtol,
atol=atol,
skip_check=skip_check,
profiler=profiler,
target=target)
return decorator
return wrapper
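# Usage sketch (illustrative only): `autotune` wraps a `jit`-decorated TL program
# builder, substitutes each config's values for the bound arguments named in
# `keys`, and returns (best_latency, best_config, ref_latency).
# `my_matmul` and its tile-size parameters below are hypothetical placeholders.
#
#   configs = [{"block_M": 64, "block_N": 64}, {"block_M": 128, "block_N": 64}]
#
#   @autotune(configs=configs, keys=["block_M", "block_N"], warmup=10, rep=50)
#   @jit(out_idx=[2], supply_type=tl.TensorSupplyType.Normal, target="auto")
#   def my_matmul(M, N, K, block_M, block_N):
#       ...  # build and return the TL program for these tile sizes
#
#   # The trailing 64, 64 are placeholders; the autotuner overrides them per config.
#   best_latency, best_config, ref_latency = my_matmul(1024, 1024, 1024, 64, 64)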
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .transform_kind import TransformKind # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copied from bitblas
from enum import IntEnum
class TransformKind(IntEnum):
NonTransform = 0
InterWarpTransform = 1
IntraWarpTransform = 2
LDMatrixTransform = 3
def is_non_transform(self):
return self == TransformKind.NonTransform
def is_inter_warp_transform(self):
return self == TransformKind.InterWarpTransform
def is_intra_warp_transform(self):
return self == TransformKind.IntraWarpTransform
def is_ld_matrix_transform(self):
return self == TransformKind.LDMatrixTransform
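# Minimal sanity check: TransformKind is an IntEnum, so transform levels can be
# compared directly and the predicates make dispatch on the transform level explicit.
assert TransformKind.IntraWarpTransform.is_intra_warp_transform()
assert TransformKind.LDMatrixTransform > TransformKind.NonTransform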
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .nvcc import compile_cuda # noqa: F401
from .hipcc import compile_hip # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
"""Utility to invoke hipcc compiler in the system"""
# File is copied from a modified version of hipcc.py to support
# compilation of HIP code with hipcc compiler
# Source Path:
# https://github1s.com/TileLang/tvm/blob/upstream/python/tvm/contrib/hipcc.py
from __future__ import absolute_import as _abs
import subprocess
import tvm._ffi
from tvm.contrib import utils
from tvm._ffi.base import py_str
from tvm.contrib.rocm import get_rocm_arch, find_rocm_path
def compile_hip(code,
target_format="hsaco",
arch=None,
options=None,
path_target=None,
verbose=False):
"""Compile HIP code with hipcc.
Parameters
----------
code : str
The HIP code.
target_format : str
The target format of hipcc compiler.
arch : str
The AMD GPU architecture.
options : str or list of str
The additional options.
path_target : str, optional
Output file.
Return
------
hsaco : bytearray
The bytearray of the hsaco
"""
if arch is None:
rocm_path = find_rocm_path()
arch = get_rocm_arch(rocm_path)
temp = utils.tempdir()
if target_format not in ["hsaco"]:
raise ValueError("target_format must be hsaco")
temp_code = temp.relpath("my_kernel.cc")
temp_target = temp.relpath("my_kernel.%s" % target_format)
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["hipcc"]
cmd += ["-O3", '-c']
if isinstance(arch, str):
cmd += [f"--offload-arch={arch}"]
if target_format == "hsaco":
cmd += ["--genco"]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if verbose:
print(py_str(out))
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
@tvm._ffi.register_func("tvm_callback_hip_compile", override=True)
def tvm_callback_hip_compile(code, target):
"""use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco")
return hsaco
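# Usage sketch (illustrative): ahead-of-time compilation of a HIP kernel string.
# The kernel source and the gfx942 (MI300-series) arch below are assumptions.
#
#   hip_src = 'extern "C" __global__ void scale(float* x) { x[threadIdx.x] *= 2.f; }'
#   hsaco = compile_hip(hip_src, target_format="hsaco", arch="gfx942", verbose=True)
#   print(len(hsaco), "bytes of hsaco")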
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs
import os
import subprocess
import warnings
import tvm._ffi
from tvm.target import Target
from tvm._ffi.base import py_str
from tvm.contrib import utils
def compile_cuda(code,
target_format="ptx",
arch=None,
options=None,
path_target=None,
verbose=False):
"""Compile cuda code with NVCC from env.
Parameters
----------
code : str
The cuda code.
target_format : str
The target format of nvcc compiler.
arch : str
The cuda architecture.
options : str or list of str
The additional options.
path_target : str, optional
Output file.
Return
------
cubin : bytearray
The bytearray of the cubin
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split("."))
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
temp = utils.tempdir()
file_name = "tvm_kernels"
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath(f"{file_name}.cu")
temp_target = temp.relpath(f"{file_name}.{target_format}")
pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None))
if kernels_output_dir is not None:
if not os.path.isdir(kernels_output_dir):
os.makedirs(kernels_output_dir)
temp_code = os.path.join(kernels_output_dir, f"{file_name}.cu")
temp_target = os.path.join(kernels_output_dir, f"{file_name}.{target_format}")
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += [f"--{target_format}", "-O3"]
if kernels_output_dir is not None:
cmd += ["-lineinfo"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if verbose:
print(py_str(out))
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
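# Usage sketch (illustrative): compiling a CUDA kernel string with an explicit arch.
# The kernel source and sm_80 arch below are assumptions; `arch` may also be a list
# of "-gencode" flags as described above.
#
#   cuda_src = 'extern "C" __global__ void scale(float* x) { x[threadIdx.x] *= 2.f; }'
#   cubin = compile_cuda(cuda_src, target_format="cubin", arch="sm_80", options=["-std=c++17"])
#   print(len(cubin), "bytes of cubin")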
def find_cuda_path():
"""Utility function to find cuda path
Returns
-------
path : str
Path to cuda root.
"""
if "CUDA_PATH" in os.environ:
return os.environ["CUDA_PATH"]
cmd = ["which", "nvcc"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
out = py_str(out)
if proc.returncode == 0:
return os.path.realpath(os.path.join(str(out).strip(), "../.."))
cuda_path = "/usr/local/cuda"
if os.path.exists(os.path.join(cuda_path, "bin/nvcc")):
return cuda_path
raise RuntimeError("Cannot find cuda path")
def get_cuda_version(cuda_path=None):
"""Utility function to get cuda version
Parameters
----------
cuda_path : Optional[str]
Path to cuda root. If None is passed, will use
`find_cuda_path()` as default.
Returns
-------
    version : tuple of int
        The CUDA version, e.g. (11, 8).
"""
if cuda_path is None:
cuda_path = find_cuda_path()
version_file_path = os.path.join(cuda_path, "version.txt")
if not os.path.exists(version_file_path):
# Debian/Ubuntu repackaged CUDA path
version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt")
try:
with open(version_file_path) as f:
version_str = f.read().strip().split()[-1]
return tuple(int(field) for field in version_str.split("."))
except FileNotFoundError:
pass
cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
out = py_str(out)
if proc.returncode == 0:
release_line = [l for l in out.split("\n") if "release" in l][0]
release_fields = [s.strip() for s in release_line.split(",")]
version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
return tuple(int(field) for field in version_str.split("."))
raise RuntimeError("Cannot read cuda version file")
@tvm._ffi.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
@tvm._ffi.register_func("tvm_callback_libdevice_path", override=True)
def find_libdevice_path(arch):
"""Utility function to find libdevice
Parameters
----------
arch : int
The compute architecture in int
Returns
-------
path : str
Path to libdevice.
"""
cuda_path = find_cuda_path()
lib_path = os.path.join(cuda_path, "nvvm/libdevice")
if not os.path.exists(lib_path):
# Debian/Ubuntu repackaged CUDA path
lib_path = os.path.join(cuda_path, "lib/nvidia-cuda-toolkit/libdevice")
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
major_minor = (cuda_ver[0], cuda_ver[1])
if major_minor in (
(9, 0),
(9, 1),
(10, 0),
(10, 1),
(10, 2),
(11, 0),
(11, 1),
(11, 2),
(11, 3),
):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
if not fn.startswith("libdevice"):
continue
try:
# expected pattern: libdevice.${ARCH}.10.bc
# e.g., libdevice.compute_20.10.bc
ver = int(fn.split(".")[-3].split("_")[-1])
if selected_ver < ver <= arch:
selected_ver = ver
selected_path = fn
except ValueError:
# it can just be `libdevice.10.bc` in CUDA 10
selected_path = fn
if selected_path is None:
raise RuntimeError(f"Cannot find libdevice for arch {arch}")
path = os.path.join(lib_path, selected_path)
return path
def callback_libdevice_path(arch):
try:
return find_libdevice_path(arch)
except RuntimeError:
warnings.warn("Cannot find libdevice path", stacklevel=2)
return ""
@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
Looks for the target arch in three different places, first in the target input, then the
Target.current() scope, and finally the GPU device (if it exists).
Parameters
----------
target : tvm.target.Target, optional
The compilation target
Returns
-------
compute_version : str
compute capability of a GPU (e.g. "8.6")
"""
# 1. input target object
# 2. Target.current()
target = target or Target.current()
if target and target.arch:
arch = target.arch.split("_")[1]
if len(arch) == 2:
major, minor = arch
return major + "." + minor
elif len(arch) == 3:
# This is for arch like "sm_90a"
major, minor, suffix = arch
return major + "." + minor + "." + suffix
# 3. GPU compute version
if tvm.cuda(0).exist:
return tvm.cuda(0).compute_version
    raise ValueError("No CUDA architecture was specified or GPU detected. "
"Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version):
"""Parse compute capability string to divide major and minor version
Parameters
----------
compute_version : str
compute capability of a GPU (e.g. "6.0")
Returns
-------
major : int
major version number
minor : int
minor version number
"""
split_ver = compute_version.split(".")
try:
major = int(split_ver[0])
minor = int(split_ver[1])
return major, minor
except (IndexError, ValueError) as err:
# pylint: disable=raise-missing-from
raise RuntimeError("Compute version parsing error") from err
def have_fp16(compute_version):
"""Either fp16 support is provided in the compute capability or not
Parameters
----------
compute_version: str
compute capability of a GPU (e.g. "6.0")
"""
major, minor = parse_compute_version(compute_version)
# fp 16 support in reference to:
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions
conditions = [False]
conditions.append(major == 5 and minor >= 3)
conditions.append(major >= 6)
return any(conditions)
def have_int8(compute_version):
"""Either int8 support is provided in the compute capability or not
Parameters
----------
compute_version : str
compute capability of a GPU (e.g. "6.1")
"""
major, _ = parse_compute_version(compute_version)
return major >= 6
def have_tensorcore(compute_version=None, target=None):
"""Either TensorCore support is provided in the compute capability or not
Parameters
----------
compute_version : str, optional
compute capability of a GPU (e.g. "7.0").
target : tvm.target.Target, optional
The compilation target, will be used to determine arch if compute_version
isn't specified.
"""
if compute_version is None:
if tvm.cuda(0).exist:
compute_version = tvm.cuda(0).compute_version
else:
if target is None or "arch" not in target.attrs:
                warnings.warn(
                    "Tensorcore will be disabled due to no CUDA architecture specified. "
                    "Try specifying it by adding '-arch=sm_xx' to your target.",
stacklevel=2)
return False
compute_version = target.attrs["arch"]
# Compute version will be in the form "sm_{major}{minor}"
major, minor = compute_version.split("_")[1]
compute_version = major + "." + minor
major, _ = parse_compute_version(compute_version)
return major >= 7
def have_cudagraph():
"""Either CUDA Graph support is provided"""
try:
cuda_ver = get_cuda_version()
return not cuda_ver < (10, 0)
except RuntimeError:
return False
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
Parameters
----------
compute_version : str
compute capability of a GPU (e.g. "8.0")
"""
major, _ = parse_compute_version(compute_version)
return major >= 8
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not
Parameters
----------
compute_version : str
GPU capability
"""
major, minor = parse_compute_version(compute_version)
# fp8 is supported in Ada Lovelace (8.9) or later architectures.
conditions = [False]
conditions.append(major == 8 and minor >= 9)
conditions.append(major >= 9)
return any(conditions)
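# Worked example of the capability helpers above (plain arithmetic, no GPU needed):
# "8.9" is Ada Lovelace, which gains fp8 support but is below the 9.x (Hopper) line.
assert parse_compute_version("8.9") == (8, 9)
assert have_fp16("8.9") and have_bf16("8.9") and have_fp8("8.9")
assert not have_fp8("8.6")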
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .lower import lower # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The compiler for TL programs."""
import tilelang as tl
import os
import os.path as osp
from typing import Literal, Union
from tilelang import tvm as tvm
from tvm import tir, relay
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.utils import determine_target
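# calling_conv == 2 corresponds to tvm::CallingConv::kDeviceKernelLaunch,
# i.e. PrimFuncs launched as device kernels rather than run on the host.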
def is_device_call(func: tir.PrimFunc):
return bool(func.attrs and "calling_conv" in func.attrs and func.attrs["calling_conv"] == 2)
def is_host_call(func: tir.PrimFunc):
return not is_device_call(func)
@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ:
tl_template_path = os.environ["TL_TEMPLATE_PATH"]
else:
tl_template_path = osp.abspath(osp.join(project_root, "src"))
    # TODO(lei): TL_CUTLASS_PATH should eventually be renamed to
    # TL_CUTLASS_INCLUDE_PATH.
if "TL_CUTLASS_PATH" in os.environ:
cutlass_path = os.environ["TL_CUTLASS_PATH"]
else:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# special handle for Hopper
if compute_version == "90":
arch = ["-arch=sm_90a"]
format = "cubin"
else:
arch = [f"-arch=sm_{compute_version}"]
format = "cubin"
# printing out number of registers
debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
ptx = nvcc.compile_cuda(
code,
format,
arch,
options=[
"-std=c++17",
debug_option,
"--use_fast_math",
"-I" + tl_template_path,
"-I" + cutlass_path,
],
verbose=False,
)
return ptx
@tvm.register_func("tvm_callback_hip_compile", override=True)
def tvm_callback_hip_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..")
tl_template_path = osp.abspath(osp.join(project_root, "src"))
    # TODO(lei): TL_COMPOSABLE_KERNEL_PATH should eventually be renamed to
    # TL_COMPOSABLE_KERNEL_INCLUDE_PATH.
if "TL_COMPOSABLE_KERNEL_PATH" in os.environ:
ck_path = os.environ["TL_COMPOSABLE_KERNEL_PATH"]
else:
ck_path = osp.abspath(osp.join(project_root, "3rdparty/composable_kernel/include"))
hsaco = hipcc.compile_hip(
code,
target_format="hsaco",
options=[
"-std=c++17",
"-I" + tl_template_path,
"-I" + ck_path,
],
verbose=False,
)
return hsaco
def extract_params(func: tir.PrimFunc):
buffers = [func.buffer_map[var] for var in func.params]
tensor_types = [relay.TensorType(buffer.shape, buffer.dtype) for buffer in buffers]
return tensor_types
def lower(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[Literal["auto", "cuda", "hip"], Target] = "auto",
target_host="llvm",
runtime_only=False,
):
# TODO(lei): Append C Source code host generation to the runtime
mod = func_or_mod
if isinstance(func_or_mod, tir.PrimFunc):
func = func_or_mod
        params = extract_params(func) if not runtime_only else None
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
if isinstance(target, str):
target = determine_target(target)
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
mod = tir.transform.BindTarget(target)(mod)
mod = tl.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod)
mod = tl.transform.LayoutInference()(mod)
mod = tl.transform.LowerTileOp()(mod)
mod = tl.transform.LegalizeVectorizedLoop()(mod)
mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
    # Inject Simplify to remove duplicated conditions that may be
    # introduced by LegalizeSafeMemoryAccess.
    mod = tir.transform.Simplify()(mod)
if target.arch == "sm_90":
mod = tl.transform.MultiVersionBuffer()(mod)
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tl.transform.PipelinePlanning()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
    # TODO(lei): This is a hack to make sure the thread-level allreduce pass
    # can be applied in TL. Because TL only uses one thread dimension, the
    # var binding information would otherwise be lost during lowering by the
    # Legalization and Simplify passes. A better approach would be to create
    # the var explicitly instead of placing LowerThreadAllreduce before
    # Legalization.
mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tir.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
host_mod = tir.transform.Filter(is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod)
host_mod = tir.transform.FP8StorageLegalize()(host_mod)
host_mod = tir.transform.BF16StorageLegalize()(host_mod)
host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
host_mod = tir.transform.LowerIntrin()(host_mod)
host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
host_mod = tir.transform.CombineContextCall()(host_mod)
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
else:
raise ValueError("Target host is not supported")
device_mod = tir.transform.Filter(is_device_call)(mod)
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
# Debug comments to get the code
# code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
else:
raise ValueError("Target is not supported")
host_mod.import_module(device_mod)
if runtime_only is True:
return host_mod
else:
return host_mod, params
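# Usage sketch (illustrative): lowering a TL program and wrapping it in a Profiler,
# mirroring how the autotuner's `jit` decorator uses this function. `my_program`
# is a hypothetical tir.PrimFunc built with the TL frontend; [1] marks which
# buffer(s) are treated as outputs.
#
#   rt_mod, params = lower(my_program, target="auto")
#   profiler = tl.Profiler(rt_mod, params, [1], tl.TensorSupplyType.Normal)
#
# With runtime_only=True only the host runtime module is returned and no
# parameter tensor types are extracted.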
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .utils import (
mma_store_index_map, # noqa: F401
get_ldmatrix_offset, # noqa: F401
)
from .mma_macro_generator import (
TensorCoreIntrinEmitter, # noqa: F401
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401
from .mfma_layout import make_mfma_swizzle_layout # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import DataType
from tvm.runtime import convert
import tilelang.language as T
def shared_16x4_to_local_64x1_layout_A(i, j):
thread_id = (j * 16 + i)
return thread_id, convert(0)
def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
i = thread_id % 16
j = thread_id // 16
return i, j
def shared_4x16_to_local_64x1_layout_B(i, j):
thread_id = (i * 16 + j)
return thread_id, convert(0)
def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
i = thread_id // 16
j = thread_id % 16
return i, j
def shared_16x16_to_local_64x4_layout_C(i, j):
thread_id = j + (i // 4) * 16
local = (i % 4)
return thread_id, local
def shared_16x16_to_ldmatrix_64x4_layout(ind):
i, j = ind[0], ind[1]
thread_id, local_id = shared_16x16_to_local_64x4_layout_C(i, j)
return convert([thread_id, local_id])
def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
i = thread_id % 16
j = (thread_id // 16) * 4 + local_id
return i, j
def shared_16x16_to_local_64x4_layout_A(i, j):
thread_id = i + 16 * (j // 4)
local = (j % 4)
return thread_id, local
def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
i = local_id + (thread_id // 16) * 4
j = thread_id % 16
return i, j
def shared_16x16_to_local_64x4_layout_B(i, j):
thread_id = j + (i // 4) * 16
local = (i % 4)
return thread_id, local
def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id):
i = local_id + (thread_id // 16) * 4
j = thread_id % 16
return i, j
def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
i = thread_id % 16
j = local_id + (thread_id // 16) * 4
return i, j
def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
i = thread_id % 16
j = (thread_id // 16) * 8 + local_id
return i, j
def shared_16x32_to_local_64x8_layout_A(i, j):
thread_id = i + 16 * (j // 8)
local = (j % 8)
return thread_id, local
def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
i = local_id + (thread_id // 16) * 8
j = thread_id % 16
return i, j
def shared_16x32_to_local_64x8_layout_B(i, j):
thread_id = j + (i // 8) * 16
local = (i % 8)
return thread_id, local
def make_mfma_swizzle_layout(shared_buf, vecSize=8):
dtype = shared_buf.dtype
shape = shared_buf.shape
numBanks = 32
bankBitWidth = 32
SIMDWidth = 16
innerDimLength = shape[-1]
typeWidthInBit = DataType(dtype).bits
elemsPerOneBanksRow = (numBanks * bankBitWidth) // typeWidthInBit
perPhase = max(1, elemsPerOneBanksRow // innerDimLength)
maxPhase = min(SIMDWidth // perPhase, innerDimLength // vecSize)
def transform(row, col):
phase = (row // perPhase) % maxPhase
colOffSwizzled = ((col // vecSize) ^ phase) * vecSize
colOffOrdered = col % vecSize
colOff = colOffSwizzled + colOffOrdered
return row, colOff
return T.Layout(shape, transform)
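# Quick sanity check (plain Python ints): the forward and reverse 16x16 <-> 64x4
# A-matrix maps above are inverses of each other, e.g. for element (i=5, j=9).
assert shared_16x16_to_local_64x4_layout_A(5, 9) == (37, 1)
assert thread_id_shared_access_64x4_to_16x16_layout_A(37, 1) == (5, 9)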
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
from typing import Optional
from .utils import (
mfma_store_index_map,)
lift = convert
class MatrixCoreIntrinEmitter(object):
    """
    Emit AMD Matrix Core (MFMA) load/compute/store macros while keeping raw
    Python control flow out of the TIR macro bodies.
    """
M_DIM = 16
N_DIM = 16
WARP_SIZE = 64
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
}
    # k_pack represents the number of elements in a vectorized instruction.
    # Details can be found in the Triton source:
    # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419
k_pack = 1
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mfma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
if a_dtype.bits == 32:
self.k_dim = 4
elif a_dtype.bits in [16, 8]:
self.k_dim = 16
else:
raise ValueError(f"Unsupported a_dtype = {a_dtype}")
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
def _initialize_mfma_prefix(self, k_dim=16):
in_dtype, out_dtype = self.a_dtype, self.accum_dtype
M_DIM, N_DIM = self.M_DIM, self.N_DIM
out_dtype_abbrv = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32"
}[out_dtype]
in_dtype_abbrv = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32"
}[in_dtype]
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_k_pack(self, k_pack: Optional[int] = None):
if k_pack is not None:
self.k_pack = k_pack
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
)
k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed
if k_dim == 4:
index_map = shared_16x4_to_local_64x1_layout_A
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
if is_b:
index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
elif k_dim == 16:
index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
if is_b:
index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
elif k_dim == 32:
index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
if is_b:
index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
else:
            raise ValueError("k_dim must be 4, 16, or 32 currently")
return index_map, reverse_index_map
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
'''
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
        if is_m_first:
            lane_id, warp_n, warp_m = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_col_warps,
                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
            )
            return lane_id, warp_n, warp_m
        else:
            lane_id, warp_m, warp_n = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_row_warps,
                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
            )
            return lane_id, warp_n, warp_m
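    # Worked example (defaults: is_m_first=False, WARP_SIZE=64, block_row_warps=2,
    # block_col_warps=2): thread_id 130 maps to lane_id 130 % 64 = 2,
    # warp_m (130 // 64) % 2 = 0 and warp_n (130 // 128) % 2 = 1.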
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
k_pack = self.k_pack
is_transposed = self.a_transposed
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_bindings,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_bindings)
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * micro_size_k,
warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
k_pack = self.k_pack
is_transposed = self.b_transposed
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * chunk + ki * micro_size_k,
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
def mfma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
k_pack = self.k_pack
mfma_suffix = self.mfma_suffix
a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
T.tvm_mfma(
mfma_suffix,
"row",
"row",
compute_a_dtype,
compute_b_dtype,
compute_out_dtype,
B_local_buf.data,
((j * k_pack + kp) * local_size_b) // local_size_b,
A_local_buf.data,
((i * k_pack + kp) * local_size_a) // local_size_a,
C_local_buf.data,
(i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
dtype=compute_out_dtype,
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_out = self.local_size_out
is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, N_DIM = self.M_DIM, self.N_DIM
        # STS (store to shared)
        # The MMA store is emulated manually instead of using the TVM intrinsic,
        # because the TVM intrinsic assumes threadIdx.x always equals the warp size.
@T.macro
def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
return _warp_stmatrix_global(C_local_buf, C_buf,
thread_bindings) if is_global else _warp_stmatrix_shared(
C_local_buf, C_buf, thread_bindings)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from tvm import arith, DataType
import tilelang.language as T
def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
row = thread_id % 16
col = 8 * (thread_id // 16) + local_id % 8
return row, col
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8)
col = 8 * ((thread_id % 16) // 8) + local_id % 8
return row, col
def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id):
row = thread_id % 16
col = 16 * (thread_id // 16) + local_id % 16
return row, col
def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8)
col = 16 * ((thread_id % 16) // 8) + local_id % 16
return row, col
def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
row = thread_id % 16
col = local_id + (thread_id // 16) * 16
return row, col
def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
row = (thread_id // 16) * 8 + (thread_id % 8)
col = local_id + 16 * ((thread_id % 16) // 8)
return row, col
def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (local_id % 4 // 2) + (thread_id // 4)
col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
return row, col
def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8)
def shared_16x32_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)
def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str):
dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
BANK_SIZE_BYTES // dtype.bits)
# use transaction bits to support diverse dtype.
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
coalescent_bits = dtype.bits * row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None
if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub = row_idx % bank_elems
new_col_idx_outer = col_idx_outer ^ row_idx_sub
else:
assert coalescent_bits % 512 == 0
# Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
        # 0 1 2 3 0 1 2 3 ==> 0 1 2 3 0 1 2 3
        # 0 1 2 3 0 1 2 3 ==> 1 0 3 2 1 0 3 2
        # 0 1 2 3 0 1 2 3 ==> 2 3 0 1 2 3 0 1
        # 0 1 2 3 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems
# Interleave elems per byte
interleave_elems = 32 // dtype.bits
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)
assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
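# Worked example of the 8x8 permuted pattern above, in plain integers (fp16:
# bank_elems = 128 bits / 16 bits = 8). Element (row=3, col=40) is relocated to
# column ((40 // 8) ^ (3 % 8)) * 8 + 40 % 8 = (5 ^ 3) * 8 + 0 = 48.
assert ((40 // 8) ^ (3 % 8)) * 8 + 40 % 8 == 48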
def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits % 512 == 0
if is_smooth or (not can_swizzle):
return T.Layout(shape, lambda *args: args)
def transform_func(*args):
i, j = args[-2:]
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [*args[:-2], new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
from typing import Union, Tuple, Optional, Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)
lift = convert
# TODO(lei): Add Typing for this file
class TensorCoreIntrinEmitter(object):
    """
    Emit NVIDIA Tensor Core (MMA) load/compute/store macros while keeping raw
    Python control flow out of the TIR macro bodies.
    """
M_DIM = 16
N_DIM = 16
WARP_SIZE = 32
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_is_m_first(is_m_first)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}")
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
self.k_dim = 256 // a_dtype.bits
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
def _initialize_mma_prefix(self, k_dim=16):
if k_dim == 16:
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
self.mma_prefix = "m16n8k32"
else:
raise ValueError("Unsupported k_dim")
def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
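    # Worked example (defaults: is_m_first=False, WARP_SIZE=32, block_row_warps=2,
    # block_col_warps=2): thread_id 70 maps to lane_id 70 % 32 = 6,
    # warp_m (70 // 32) % 2 = 0 and warp_n (70 // 64) % 2 = 1.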
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_dtype = self.a_dtype
a_transposed = self.a_transposed
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_bindings,
rk=0,
):
stride = A_shared_buf.shape[-1]
tx, _, warp_m = self.extract_thread_binding(thread_bindings)
for i in T.serial(warp_rows):
T.ptx_ldmatrix(
a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf[
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_dtype = self.b_dtype
b_transposed = self.b_transposed
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
rk=0,
):
stride = B_shared_buf.shape[-1]
tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
for j in T.serial(warp_cols):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
B_shared_elem = B_shared_buf[ri, rj]
T.ptx_ldmatrix(
b_dtype,
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
def mma(self, A_local_buf, B_local_buf, C_local_buf, k_inner=0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
k_inner * warp_rows * local_size_a + i * local_size_a,
B_local_buf.data,
k_inner * warp_cols * local_size_b + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
k_inner * warp_rows * local_size_a + i * local_size_a,
B_local_buf.data,
k_inner * warp_cols * local_size_b + j * local_size_b
+ lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
T.bool(False),
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
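# Note: each (i, j) tile issues two m16n8 MMAs because the 16x16 accumulator is
# split into two n8 halves; the second call offsets B by local_size_b // 2 and C by
# local_size_out // 2 relative to the first. With local_size_b = local_size_out = 8
# (illustrative values), the two halves of C for tile (i, j) start at
# i * warp_cols * 8 + j * 8 and i * warp_cols * 8 + j * 8 + 4.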
def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_out = self.local_size_out
is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, N_DIM = self.M_DIM, self.N_DIM
# STS
# The MMA store is emitted manually instead of via the TVM intrinsic,
# because the TVM intrinsic assumes threadIdx.x is always equal to the warp size.
@T.macro
def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id_o in T.serial(local_size_out // 2):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id_o in T.serial(local_size_out // 2):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings))
def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for loading an MMA operand (A or B) into a fragment buffer.
This layout is used in conjunction with `inverse_mma_load_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
matrix : Literal["A", "B"]
Which operand the fragment holds; selects the dtype, transpose flag,
and ldmatrix layout.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.primitives.utils import is_fragment
from tilelang.intrinsics.mma_layout import (
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_16x32_to_shared_16x32_layout_a,
ldmatrix_16x32_to_shared_16x32_layout_b,
)
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype = self.a_dtype if matrix == "A" else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed if matrix == "A" else self.b_transposed
transform_func: Callable = None
transform_func_trans: Callable = None
if dtype_bits == 16:
transform_func = ldmatrix_32x8_to_shared_16x16_layout
transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype_bits == 8:
if matrix == "B" and transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
else:
raise ValueError(f"Unsupported dtype {dtype}")
shape = local_buf.shape
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(local_buf.scope())
if matrix == "A":
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_k
else:
micro_size_x, micro_size_y = self.micro_size_k, self.micro_size_y
if transposed:
micro_size_x, micro_size_y = micro_size_y, micro_size_x
local_size_out = self.local_size_out
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
# For 16-bit dtypes the transposed case has its own ldmatrix mapping; for 8-bit dtypes
# the transpose is already baked into the selected mapping, so only swap when available.
transform_func = transform_func_trans if (transposed and transform_func_trans is not None) else transform_func
local_size = self.local_size_a if matrix == "A" else self.local_size_b
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32").inverse([warp_size, local_size])
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_load_layout`.
"""
# i ranges over block_row_warps * warp_rows * micro_size_x, j over block_col_warps * warp_cols * micro_size_y
# block_i and block_j index the warp within the block, bounded by block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_load_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = (
block_i * (block_col_warps * warp_size)
+ block_j * warp_size
+ lane_id
)
else:
thread_id = (
block_j * (block_row_warps * warp_size)
+ block_i * warp_size
+ lane_id
)
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_load_layout`.
"""
# i ranges over block_row_warps * warp_rows * micro_size_x, j over block_col_warps * warp_cols * micro_size_y
# block_i and block_j index the warp within the block, bounded by block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_load_layout.map_indices([mma_i, mma_j])
return (
warp_i * (warp_cols * local_size_out)
+ warp_j * local_size_out
+ local_id
)
fragment = T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
print(f"fragment.shape: {local_buf.shape}")
print(f"fragment.thread: {fragment.thread}")
print(f"fragment.index: {fragment.index}")
return fragment
def make_mma_store_layout(
self, local_buf: Buffer
) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.primitives.utils import is_fragment
shape = local_buf.shape
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
# i ranges over block_row_warps * warp_rows * micro_size_x, j over block_col_warps * warp_cols * micro_size_y
# block_i and block_j index the warp within the block, bounded by block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_size) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
# i ranges over block_row_warps * warp_rows * micro_size_x, j over block_col_warps * warp_cols * micro_size_y
# block_i and block_j index the warp within the block, bounded by block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
"""
Tensor Core intrinsic emitter that keeps Python-level control flow out of the TIR macros,
extended with support for the Ladder layout transform plugin.
"""
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
transform_kind_a: Union[int, TransformKind] = 0,
transform_kind_b: Union[int, TransformKind] = 0,
):
super().__init__(
a_dtype=a_dtype,
b_dtype=b_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
num_elems_per_byte=num_elems_per_byte,
is_m_first=is_m_first,
)
self._initialize_transform_kind(transform_kind_a, transform_kind_b)
def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 256 // DataType(a_dtype).bits
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
def _initialize_mma_prefix(self, k_dim=16):
if k_dim == 16:
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
self.mma_prefix = "m16n8k32"
else:
raise ValueError("Unsupported k_dim")
def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_transform_kind(self, transform_kind_a, transform_kind_b):
if isinstance(transform_kind_a, int):
self.transform_kind_a = TransformKind(transform_kind_a)
elif isinstance(transform_kind_a, TransformKind):
self.transform_kind_a = transform_kind_a
else:
raise ValueError("Unsupported transform_kind_a")
if isinstance(transform_kind_b, int):
self.transform_kind_b = TransformKind(transform_kind_b)
elif isinstance(transform_kind_b, TransformKind):
self.transform_kind_b = transform_kind_b
else:
raise ValueError("Unsupported transform_kind_b")
assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3"
assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3"
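# Minimal construction sketch (parameter values are illustrative only): the
# ladder-transform emitter takes the same arguments as the base emitter plus
# transform_kind_a / transform_kind_b, which select the layout transform stage
# (0 = NonTransform through 3 = LDMatrixTransform in this codebase's TransformKind enum).
#
# emitter = TensorCoreIntrinEmitterWithLadderTransform(
#     a_dtype="float16",
#     b_dtype="float16",
#     accum_dtype="float32",
#     block_row_warps=2,
#     block_col_warps=2,
#     warp_row_tiles=32,
#     warp_col_tiles=32,
#     chunk=32,
#     transform_kind_b=TransformKind.LDMatrixTransform,
# )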
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_dtype = self.a_dtype
a_transposed = self.a_transposed
transform_kind_a = self.transform_kind_a
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_bindings,
rk=0,
):
stride = A_shared_buf.shape[-1]
tx, _, warp_m = self.extract_thread_binding(thread_bindings)
if transform_kind_a == TransformKind.NonTransform:
for i in T.serial(warp_rows):
T.ptx_ldmatrix(
a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf[
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
elif transform_kind_a == TransformKind.InterWarpTransform:
for i in T.serial(warp_rows):
# Assign A_shared_elem
ri, rj = (
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
)
ni, nj, nii, njj = (
(ri) // micro_size_x,
(rj) // micro_size_k,
(ri) % micro_size_x,
(rj) % micro_size_k,
)
A_shared_elem = A_shared_buf[ni, nj, nii, njj]
T.ptx_ldmatrix(
a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_elem),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
elif transform_kind_a == TransformKind.IntraWarpTransform:
for i in T.serial(warp_rows):
# Assign A_shared_elem
ri, rj = (
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
)
ni, nj, nii, njj = (
(ri) // micro_size_x,
(rj) // micro_size_k,
(ri) % micro_size_x,
(rj) % micro_size_k,
)
A_shared_elem = A_shared_buf[ni, nj, nii, njj]
T.ptx_ldmatrix(
a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_elem),
tx * local_size_a,
)
elif transform_kind_a == TransformKind.LDMatrixTransform:
for j in T.serial(warp_rows):
for local_id in T.vectorized(local_size_a):
# Assign A_shared_elem
ri, rj = (
warp_m * warp_rows + j,
rk * (chunk // micro_size_k) + ki,
)
rii, rjj = (tx * local_size_a +
local_id) // micro_size_k, (tx * local_size_a + local_id) % (
micro_size_k)
A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj])
else:
raise ValueError("Unsupported TransformKind for Input A")
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_dtype = self.b_dtype
transform_kind_b = self.transform_kind_b
b_transposed = self.b_transposed
num_elems_per_byte = self.num_elems_per_byte
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
rk=0,
):
stride = B_shared_buf.shape[-1]
tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
if transform_kind_b == TransformKind.NonTransform:
for j in T.serial(warp_cols):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
B_shared_elem = B_shared_buf[ri, rj]
T.ptx_ldmatrix(
b_dtype,
T.bool(False),
4,
".b16",
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
elif transform_kind_b == TransformKind.InterWarpTransform:
for j in T.serial(warp_cols):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
ni, nj, nii, njj = (
(ri) // micro_size_y,
(rj) // micro_size_k,
(ri) % micro_size_y,
(rj) % micro_size_k,
)
B_shared_elem = B_shared_buf[ni, nj, nii, njj]
T.ptx_ldmatrix(
b_dtype,
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
elif transform_kind_b == TransformKind.IntraWarpTransform:
for j in T.serial(warp_cols):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
ni, nj, nii, njj = (
(ri) // micro_size_y,
(rj) // micro_size_k,
(ri) % micro_size_y,
(rj) % micro_size_k,
)
B_shared_elem = B_shared_buf[ni, nj, nii, njj]
T.ptx_ldmatrix(
b_dtype,
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
tx * local_size_b,
)
elif transform_kind_b == TransformKind.LDMatrixTransform:
local_size_dequantize = local_size_b // num_elems_per_byte
for j in T.serial(warp_cols):
for local_id in T.vectorized(local_size_dequantize):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
rii, rjj = (tx * local_size_dequantize +
local_id) // (micro_size_k // num_elems_per_byte), (
tx * local_size_dequantize + local_id) % (
micro_size_k // num_elems_per_byte)
B_local_buf[j * local_size_dequantize + local_id] = (
B_shared_buf[ri, rj, rii, rjj])
else:
raise ValueError("Unsupported TransformKind for Input B")
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = "int4"
b_dtype_abbrv = "int4"
accum_dtype = self.accum_dtype
accum_dtype_abbrv = accum_dtype
mma_prefix = "m16n8k32"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
"""
A[16, 32], B[16, 32], C[16, 16]
A_local_size -> 16
B_local_size -> 16
C_local_size -> 8
For each m16n8k32 inst
For A: m16k32 consumes 16 int4 elements -> 8 A_local_size
For B: n8k32 consumes 8 int4 elements -> 4 B_local_size
For C: m16n8 consume 4 int32 elements -> 4 C_local_size
"""
# A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
# A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
# A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a + lift(local_size_a) // 2,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
# A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a + lift(local_size_a) // 2,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = "int4"
b_dtype_abbrv = "int4"
accum_dtype = self.accum_dtype
accum_dtype_abbrv = "int32"
mma_prefix = "m16n8k32"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
"""
A[16, 32], B[16, 32], C[16, 16]
A_local_size -> 16
B_local_size -> 16
C_local_size -> 8
For each m16n8k32 inst
For A: m16k32 consumes 16 int4 elements -> 8 A_local_size
For B: n8k32 consumes 8 int4 elements -> 4 B_local_size
For C: m16n8 consume 4 int32 elements -> 4 C_local_size
"""
# A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
# A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
# A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a + lift(local_size_a) // 2,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
# A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16]
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
i * local_size_a + lift(local_size_a) // 2,
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import DataType
from typing import Literal
from .mma_layout import (
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_16x32_to_shared_16x32_layout_a,
ldmatrix_16x32_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401
from .mfma_layout import make_mfma_swizzle_layout # noqa: F401
# The original implementation and insight are from the following code snippet:
# 3rdparty/tvm/python/tvm/tir/tensor_intrin/cuda.py#get_ldmatrix_intrin
def get_ldmatrix_offset(
matrix: Literal["A", "B"],
row_idx,
col_idx,
stride,
dtype: Literal["float16", "int8"] = "float16",
transposed: bool = False,
):
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
if dtype_bits == 16:
transform_func = ldmatrix_32x8_to_shared_16x16_layout
transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
if transposed:
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif dtype_bits == 8:
if matrix == "B" and transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
else:
raise ValueError(f"Unsupported dtype {dtype}")
def shared_16x16_to_mma_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
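# Worked example: shared_16x16_to_mma_32x8_layout(3, 5)
#   thread_id = 4 * (3 % 8) + (5 % 8) // 2 = 12 + 2 = 14
#   local_id  = 4 * (5 // 8) + (3 // 8) * 2 + (5 % 2) = 0 + 0 + 1 = 1
# so shared element (3, 5) is held by thread 14 at local register index 1.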
def shared_16x32_to_mma_32x16_layout(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
def shared_32x16_to_mma_32x16_layout(i, j):
thread_id = (i % 16) // 4 + 4 * (j % 8)
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
def mma_store_index_map(thread_id, local_id):
return mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)
def mfma_store_index_map(thread_id, local_id):
return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# TODO(lei): FP8 related precision support.
# Basic Tensor Core Matrix Multiply operation Unit
micro_size_x = micro_size_y = 16
micro_size_k = 16
if dtype == "int8":
micro_size_k = 32
return micro_size_x, micro_size_y, micro_size_k
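# The micro tile is the shape of a single Tensor Core MMA unit, e.g.:
#   get_mma_micro_size("float16") -> (16, 16, 16)
#   get_mma_micro_size("int8")    -> (16, 16, 32)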
def index_to_coordinates(index, shape):
'''
Convert a flat index into coordinates for the given shape, peeling off
dimensions from the innermost one outwards:
coordinates[-1] = index % shape[-1]; index //= shape[-1]
coordinates[-2] = index % shape[-2]; index //= shape[-2]
coordinates[-3] = index % shape[-3]; index //= shape[-3]
...
This is the general form of hand-written decompositions such as:
vjj = index % (micro_size_k // num_elems_per_byte)
vii = index // (micro_size_k // num_elems_per_byte) % micro_size_y
vj = index // (micro_size_k // num_elems_per_byte * micro_size_y) % (block_K // (micro_size_k // num_elems_per_byte))
vi = index // (micro_size_k // num_elems_per_byte * micro_size_y * (block_K // (micro_size_k // num_elems_per_byte))) % (block_N // micro_size_y)
'''
coordinates = []
dims = len(shape)
for i in range(dims):
coordinates.append(index % shape[dims - i - 1])
index = index // shape[dims - i - 1]
coordinates.reverse()
return coordinates
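# Worked example: index_to_coordinates(10, [2, 3, 4]) -> [0, 2, 2]
#   10 % 4 = 2, 10 // 4 = 2; 2 % 3 = 2, 2 // 3 = 0; 0 % 2 = 0
# and indeed 0 * (3 * 4) + 2 * 4 + 2 == 10.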
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from typing import Optional
from tvm.script import tir as T
from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .kernel import Kernel # noqa: F401
from .allocate import (
alloc_local, # noqa: F401
alloc_shared, # noqa: F401
alloc_fragment, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
reduce, # noqa: F401
reduce_max, # noqa: F401
reduce_min, # noqa: F401
reduce_sum, # noqa: F401
reduce_abssum, # noqa: F401
)
from .customize import (
atomic_add, # noqa: F401
atomic_addx2, # noqa: F401
dp4a, # noqa: F401
)
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# If order is "row", use rasterization2DRow; otherwise use rasterization2DColumn.
# panel_size controls the width (in thread blocks) of the swizzled panel,
# which is used to improve L2 cache locality.
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
return T.attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None
def annotate_layout(layout_map):
# layout_map is a dictionary of buffer to layout
layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
return T.block_attr({"layout_map": layout_map})
def import_source(source: Optional[str] = None):
# source is the source code to be imported
return T.block_attr({"pragma_import_c": source}) if source is not None else None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm.script import tir as T
def alloc_shared(shape, dtype, scope="shared.dyn"):
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_local(shape, dtype, scope="local"):
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_fragment(shape, dtype, scope="local.fragment"):
return T.alloc_buffer(shape, dtype, scope=scope)
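# Usage sketch (shapes and dtypes are illustrative): inside a TIR kernel these
# helpers allocate the three buffer scopes used by tile programs:
#   A_shared = alloc_shared((128, 32), "float16")    # dynamic shared memory
#   C_local = alloc_fragment((128, 128), "float32")  # per-thread fragment
#   scratch = alloc_local((8,), "float16")           # plain local registers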
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from typing import Union, List, Optional
from tvm import tir
from tvm.script import tir as T
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin(
"handle", tir.op.Op.get("tl.region"), buffer, access_type, *args
)
def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(
load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]
):
return region(load, access_type, *extents)
def buffer_region_to_tile_region(
buffer_region: tir.BufferRegion, access_type: str
):
mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region]
return region(
T.BufferLoad(buffer_region.buffer, mins), access_type, *extents
)
def copy(
src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
):
def get_extent(data):
if isinstance(data, tir.Buffer):
return data.shape
elif isinstance(data, tir.BufferRegion):
return [x.extent for x in data.region]
else:
return None
src_extent = get_extent(src)
dst_extent = get_extent(dst)
# if src_extent and dst_extent:
# ir.assert_structural_equal(src_extent, dst_extent)
if src_extent:
extent = src_extent
elif dst_extent:
extent = dst_extent
else:
raise TypeError("Can't deduce copy extents from args")
def _to_region(data, access_type):
if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type)
else:
return buffer_load_to_tile_region(data, access_type, extent)
src = _to_region(src, "r")
dst = _to_region(dst, "w")
if coalesced_width is not None:
return tir.call_intrin(
"handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width
)
else:
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst)
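# Usage sketch (buffer names and tile sizes are illustrative; typically accessed
# as T.copy in kernels): copying a tile of a global buffer into shared memory,
#   copy(A[by * block_M, ko * block_K], A_shared)
# where A_shared is a (block_M, block_K) shared buffer; the BufferLoad source is
# turned into a read region and A_shared into a write region of the same extents,
# with the extents deduced from the destination buffer.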
def c2d_im2col(
img: tir.Buffer,
col: tir.Buffer,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
):
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.c2d_im2col"),
img.access_ptr("r"),
col.access_ptr("w"),
nhw_step,
c_step,
kernel,
stride,
dilation,
pad,
)
class GemmWarpPolicy:
Square = 0
FullRow = 1
FullCol = 2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm.script import tir as T
def atomic_add(dst, value):
return T.call_extern("handle", "atomicAdd", T.address_of(dst), value)
def atomic_addx2(dst, value):
return T.call_extern(
"handle", "atomicAddx2", T.address_of(dst), T.address_of(value)
)
def dp4a(A, B, C):
return T.call_extern(
"handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C)
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm import tir
def fill(buffer: tir.Buffer, value: tir.PrimExpr):
buffer = buffer.access_ptr("w")
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
def clear(buffer: tir.Buffer):
return fill(buffer, 0)
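# Usage sketch (buffer name is illustrative): zero-initializing an accumulator
# fragment before the reduction loop,
#   clear(C_local)  # equivalent to fill(C_local, 0)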