Commit 57ab687c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: default avatarmicrosoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: default avatarMicrosoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: default avatarYu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm import tir
class GemmWarpPolicy:
    """Warp partitioning policies for GEMM, encoded as plain integers.

    The integer code is forwarded as-is to the `tl.gemm` intrinsic.
    """

    # 0: balance warps across rows/columns; 1: all warps on rows; 2: all on columns.
    Square, FullRow, FullCol = 0, 1, 2
def gemm(
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    transpose_A: bool = False,
    transpose_B: bool = False,
    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
    k_pack: int = 1,
):
    """Emit a call to the `tl.gemm` intrinsic accumulating A @ B into C.

    Parameters
    ----------
    A, B : tir.Buffer
        Input operand buffers (accessed read-only).
    C : tir.Buffer
        Accumulator buffer (accessed read-write).
    transpose_A, transpose_B : bool
        Whether the corresponding operand is stored transposed.
    policy : GemmWarpPolicy
        Warp partitioning policy code forwarded to the lowering pass.
    k_pack : int
        The number of k dimension that is packed into a single warp.
        please ref to mfma macro generator for the detail information.
    """
    M, N = C.shape[0], C.shape[1]
    # K is read from whichever axis of A/B holds the reduction dimension.
    K = A.shape[0] if transpose_A else A.shape[1]
    K_B = B.shape[1] if transpose_B else B.shape[0]
    assert K == K_B, "gemm K shape check failed"
    operand_ptrs = (A.access_ptr("r"), B.access_ptr("r"), C.access_ptr("rw"))
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.gemm"),
        *operand_ptrs,
        transpose_A,
        transpose_B,
        M,
        N,
        K,
        policy,
        k_pack,
    )
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from typing import Union, List, Tuple, Optional
from collections import deque
from tvm import tir
from tvm.tir import Var
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm._ffi import register_object
from tilelang import _ffi_api
class FrameStack:
    """
    A minimal LIFO wrapper around ``collections.deque`` exposing
    push/pop/top plus ``len()`` and truthiness support.
    """

    def __init__(self):
        self._stack = deque()

    def push(self, item):
        """Place *item* on top of the stack."""
        self._stack.append(item)

    def pop(self):
        """Remove and return the top item; raise IndexError when empty."""
        if not self._stack:
            raise IndexError(f"{self.__class__.__name__} is empty")
        return self._stack.pop()

    def top(self):
        """Return (without removing) the top item; raise IndexError when empty."""
        if not self._stack:
            raise IndexError(f"{self.__class__.__name__} is empty")
        return self._stack[-1]

    def __len__(self):
        """Number of items currently on the stack."""
        return len(self._stack)

    def __bool__(self):
        """True when the stack holds at least one item, e.g. 'if stack: ...'."""
        return bool(self._stack)
# Module-level stack tracking the dynamically nested KernelLaunchFrame scopes.
# A FrameStack (deque-backed) is used instead of a plain list or deque so that
# push/pop/top semantics are explicit.
_kernel_launch_frame_stack = FrameStack()
@register_object("tl.KernelLaunchFrame")
class KernelLaunchFrame(TIRFrame):
    """
    KernelLaunchFrame is a custom TIRFrame that manages block/thread indices
    and handles the entry and exit of the kernel launch scope.

    The last four frames hold thread-related bindings (extents are read from
    frames[-4..-2] for threadIdx.x/y/z); every frame before them binds one
    blockIdx dimension.
    """

    def __enter__(self) -> Union[Var, List[Var]]:
        """
        Enters the KernelLaunchFrame scope and pushes this frame onto the stack.

        Returns one Var if we detect exactly 5 frames (a single block
        dimension), or a list of Vars (all frames except the last four)
        otherwise.
        """
        super().__enter__()
        _kernel_launch_frame_stack.push(self)
        if len(self.frames) == 5:
            return self.frames[0].iter_var.var
        return [frame.iter_var.var for frame in self.frames[0:-4]]

    def __exit__(self, ptype, value, trace):
        """
        Exits the KernelLaunchFrame scope and pops this frame from the stack,
        but only if it is indeed the topmost frame (guards against unbalanced
        enter/exit).
        """
        if _kernel_launch_frame_stack.top() is self:
            _kernel_launch_frame_stack.pop()
        super().__exit__(ptype, value, trace)

    @classmethod
    def Current(cls) -> Optional["KernelLaunchFrame"]:
        """
        Returns the topmost (current) KernelLaunchFrame from the stack if it
        exists, or None if no kernel launch scope is active.
        """
        # Fix: the Optional annotation and docstring promise None on an empty
        # stack, but FrameStack.top() raises IndexError — check emptiness first.
        if _kernel_launch_frame_stack:
            return _kernel_launch_frame_stack.top()
        return None

    def get_block_extent(self, dim: int) -> int:
        """
        Returns the block extent for the given dimension.
        dim=0 corresponds to blockIdx.x, dim=1 to blockIdx.y, and dim=2 to blockIdx.z.
        """
        iter_var = self.frames[dim].iter_var
        return int(iter_var.dom.extent)

    def get_thread_extent(self, dim: int) -> int:
        """
        Returns the thread extent for the given dimension.
        dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
        """
        iter_var = self.frames[-4 + dim].iter_var
        return int(iter_var.dom.extent)

    def get_num_threads(self) -> int:
        """
        Returns the total number of threads (product of the three thread extents).
        """
        num_threads: int = 1
        for thread_dim in range(3):
            num_threads *= self.get_thread_extent(thread_dim)
        return num_threads

    @property
    def blocks(self) -> List[Var]:
        """
        Returns the block index variables (every frame except the last four).
        """
        return [frame.iter_var.var for frame in self.frames[0:-4]]

    @property
    def threads(self) -> List[Var]:
        """
        Returns the thread index variables (the last four frames).
        """
        return [frame.iter_var.var for frame in self.frames[-4:]]

    @property
    def num_threads(self) -> int:
        """
        Returns the total number of threads.
        """
        return self.get_num_threads()
def Kernel(
    *blocks: List[tir.PrimExpr],
    threads: Union[int, List[int], Tuple] = 128,
    prelude: Optional[str] = None,
):
    """Tools to quickly construct a GPU kernel launch frame.

    Parameters
    ----------
    blocks : List[int]
        A list of extents, may be 1-3 dimensional, representing gridDim.(x|y|z)
    threads : int
        An integer representing blockDim.x,
        or a list/tuple of integers representing blockDim.(x|y|z).
        If the value is -1, we skip the threadIdx.x binding.
    prelude : str
        The import c code of the kernel,
        will be injected before the generated kernel code.

    Returns
    -------
    res : Tuple[frame.LaunchThreadFrame]
        The result LaunchThreadFrame.
    """
    # Fix: removed the stale `layout_annotation` docstring entry — no such
    # parameter exists in the signature.
    attrs: dict = {}
    # Normalize `threads` to a 3-element [x, y, z] list, padding with 1s.
    if isinstance(threads, int):
        threads = [threads, 1, 1]
    elif isinstance(threads, (list, tuple)):
        threads = list(threads) + [1] * (3 - len(threads))
    else:
        raise ValueError("threads must be an integer or a list of integers")
    if prelude is not None:
        attrs["pragma_import_c"] = prelude
    return _ffi_api.KernelLaunch(blocks, threads, attrs)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from typing import Optional, Dict, Any
from tvm import tir
from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
    """Tools to construct nested parallel for loop.
    This can be used to create element-wise tensor expression.

    Parameters
    ----------
    extents : PrimExpr
        The extents of the iteration.
    coalesced_width : Optional[int]
        The coalesced width of the parallel loop.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    # Attach the coalesced-width hint only when the caller supplied one.
    annotations: Dict[str, Any] = (
        {} if coalesced_width is None else {"coalesced_width": coalesced_width}
    )
    return _ffi_api.Parallel(extents, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from typing import List, Optional
from tvm import tir
from tvm.tir import IntImm
from tilelang import _ffi_api
def Pipelined(
    start: tir.PrimExpr,
    stop: Optional[tir.PrimExpr] = None,
    num_stages: int = 0,
    order: Optional[List[int]] = None,
    stage: Optional[List[int]] = None,
    sync: Optional[List[List[int]]] = None,
    group: Optional[List[List[int]]] = None,
):
    """Tools to construct pipelined for loop.

    Parameters
    ----------
    start : PrimExpr
        The minimum value of iteration. When `stop` is omitted, `start` is
        interpreted as the extent and iteration begins at 0 (range-style call).
    stop : Optional[PrimExpr]
        The maximum value of iteration.
    num_stages : int
        The max number of buffer used between pipeline producers and consumers.
        if num_stages is 0, pipeline will not be enabled.
    order : Optional[List[int]]
        Pipeline ordering information; defaults to an empty list.
    stage : Optional[List[int]]
        Pipeline stage assignment; defaults to an empty list.
    sync : Optional[List[List[int]]]
        Pipeline synchronization information; defaults to an empty list.
    group : Optional[List[List[int]]]
        Pipeline grouping information; defaults to an empty list.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    if stop is None:
        # Single-argument form: Pipelined(extent) iterates [0, extent).
        stop = start
        # Preserve the dtype of the extent expression when constructing zero.
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
    # Avoid mutable default arguments: materialize empty lists here.
    if order is None:
        order = []
    if stage is None:
        stage = []
    if sync is None:
        sync = []
    if group is None:
        group = []
    # type: ignore[attr-defined] # pylint: disable=no-member
    return _ffi_api.Pipelined(
        start, stop, num_stages, order, stage, sync, group
    )
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm import tir
def reduce(
    buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool
):
    """Emit the `tl.reduce` intrinsic reducing `buffer` into `out` along `dim`.

    `reduce_type` selects the reduction kind (the callers below use "max",
    "min", "sum" and "abssum"); `clear` is forwarded to the intrinsic unchanged.
    """
    src_ptr = buffer.access_ptr("r")
    dst_ptr = out.access_ptr("w")
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        src_ptr,
        dst_ptr,
        reduce_type,
        dim,
        clear,
    )
def reduce_max(
    buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
):
    """Perform reduce max on input buffer, store the result to output buffer

    Parameters
    ----------
    buffer : Buffer
        The input buffer.
    out : Buffer
        The output buffer.
    dim : int
        The dimension to perform reduce on
    clear : bool
        If True, the output buffer is initialized before the reduction;
        if False, the reduction accumulates into the existing contents.
        NOTE(review): the previous docstring said "If set to False ...
        initialized to -inf", which contradicts the parameter name — confirm
        against the tl.reduce kernel implementation.

    Returns
    -------
    handle : PrimExpr
    """
    return reduce(buffer, out, "max", dim, clear)
def reduce_min(
    buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
):
    """Perform reduce min on input buffer, store the result to output buffer.

    Thin wrapper over `reduce` with reduce_type "min"; `clear` is forwarded.
    """
    return reduce(buffer, out, "min", dim, clear)
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
    """Perform reduce sum along `dim`; the clear flag is fixed to True."""
    return reduce(buffer, out, "sum", dim, True)
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
    """Perform absolute-sum reduction along `dim`; the clear flag is fixed to True."""
    return reduce(buffer, out, "abssum", dim, True)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from .layout import Layout # noqa: F401
from .fragment import Fragment # noqa: F401
from .swizzle import make_swizzled_layout # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm
from tvm.ir import Range
from tvm.tir import IterVar, Var
from tilelang import _ffi_api
from tilelang.layout import Layout
@tvm._ffi.register_object("tl.Fragment")
class Fragment(Layout):
    """A thread-aware layout backed by the `tl.Fragment` FFI object.

    Built from a shape, a function computing the owning thread for each
    coordinate, an optional replication factor, and an optional function
    computing the intra-thread (forward) index.
    """

    # pylint: disable=super-init-not-called
    def __init__(self, shape, forward_thread_fn, replicate=1, forward_index_fn=None):
        # One IterVar per dimension of the shape.
        axes = []
        for dim, extent in enumerate(shape):
            axes.append(IterVar(Range(0, extent), Var(f"i{dim}", "int32"), 0))
        axis_vars = [axis.var for axis in axes]
        forward_index = forward_index_fn(*axis_vars) if forward_index_fn else None
        if not isinstance(forward_index, tvm.ir.container.Array):
            # NOTE(review): when forward_index_fn is None this wraps None into
            # [None]; behavior preserved — confirm the FFI side expects this.
            forward_index = [forward_index]
        if replicate > 1:
            # Extra iteration variable for data replicated across threads.
            thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0)
            forward_thread = forward_thread_fn(*axis_vars, thread_replicate.var)
        else:
            thread_replicate = None
            forward_thread = forward_thread_fn(*axis_vars)
        self.__init_handle_by_constructor__(
            _ffi_api.Fragment,
            axes,
            forward_index,
            forward_thread,
            thread_replicate,
        )

    @property
    def thread(self):
        """The forward thread expression of this fragment (via FFI)."""
        return _ffi_api.Fragment_thread(self)

    def get_thread_size(self):
        """Return the thread size of this fragment (via FFI)."""
        return _ffi_api.Fragment_thread_size(self)

    def repeat(self, repeats, repeat_on_thread: bool = False) -> "Fragment":
        """Return a repeated Fragment, optionally repeating on the thread axis."""
        return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread)

    def condense_rep_var(self) -> "Fragment":
        """Return a Fragment with the replication variable condensed (via FFI)."""
        return _ffi_api.Fragment_condense_rep_var(self)
def make_swizzled_layout(buffer: tvm.tir.Buffer):
    """Build a swizzled layout for a 2-D buffer via the tl FFI."""
    assert len(buffer.shape) == 2
    rows, cols = (int(extent) for extent in buffer.shape)
    element_bits = int(tvm.DataType(buffer.dtype).bits)
    return _ffi_api.make_swizzled_layout(rows, cols, element_bits)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr
from tilelang import _ffi_api
@tvm._ffi.register_object("tl.Layout")
class Layout(Node):
    """An index mapping backed by the `tl.Layout` FFI object."""

    def __init__(self, shape, forward_fn):
        """Build the layout from a shape and an index-transform function.

        `forward_fn` receives one int32 variable per dimension and returns
        the forward index expression (or a list of them).
        """
        axes = [
            IterVar(Range(0, extent), Var(f"i{dim}", "int32"), 0)
            for dim, extent in enumerate(shape)
        ]
        forward_index = forward_fn(*[axis.var for axis in axes])
        if isinstance(forward_index, PrimExpr):
            # Normalize a single expression into a one-element list.
            forward_index = [forward_index]
        self.__init_handle_by_constructor__(_ffi_api.Layout, axes, forward_index)

    @property
    def index(self):
        """The forward index expression(s) of this layout (via FFI)."""
        return _ffi_api.Layout_index(self)

    def get_input_shape(self):
        """Shape of the layout's input index space (via FFI)."""
        return _ffi_api.Layout_input_shape(self)

    def get_output_shape(self):
        """Shape of the layout's output index space (via FFI)."""
        return _ffi_api.Layout_output_shape(self)

    def inverse(self) -> "Layout":
        """Return the inverse layout (via FFI)."""
        return _ffi_api.Layout_inverse(self)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm
from tilelang import _ffi_api
def make_swizzled_layout(buffer: tvm.tir.Buffer):
    """Build a swizzled layout for a 2-D buffer through the tl FFI."""
    assert len(buffer.shape) == 2
    dim0 = int(buffer.shape[0])
    dim1 = int(buffer.shape[1])
    bits = int(tvm.DataType(buffer.dtype).bits)
    return _ffi_api.make_swizzled_layout(dim0, dim1, bits)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Library information. This is a standalone file that can be used to get various info.
Modified from: https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/libinfo.py
"""
#! pylint: disable=protected-access
import os
import sys
# Optional extra directory to search for the tilelang shared library;
# appended to the candidate list in get_dll_directories when set.
TILELANG_LIBRARY_PATH = os.environ.get("TILELANG_LIBRARY_PATH", None)
def get_env_paths(env_var, splitter):
    """Split the value of *env_var* on *splitter*, stripping whitespace.

    Returns an empty list when the variable is unset or empty.
    """
    raw = os.environ.get(env_var)
    if not raw:
        return []
    return [piece.strip() for piece in raw.split(splitter)]
def get_dll_directories():
    """Get extra tile lang dll directories.

    Collects candidate directories (package dir, pypi lib dir, local build
    dirs, optional overrides and platform library paths) and returns only
    those that exist, as absolute paths.
    """
    curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
    source_dir = os.path.abspath(os.path.join(curr_dir, ".."))
    candidates = [
        curr_dir,
        os.path.join(curr_dir, "lib"),  # pypi build
        os.path.join(source_dir, "build"),  # local build
        os.path.join(source_dir, "build", "Release"),
    ]
    if TILELANG_LIBRARY_PATH:
        candidates.append(TILELANG_LIBRARY_PATH)
    if "CONDA_PREFIX" in os.environ:
        candidates.append(os.path.join(os.environ["CONDA_PREFIX"], "lib"))
    # Platform-specific dynamic-loader search paths.
    platform = sys.platform
    if platform.startswith("linux") or platform.startswith("freebsd"):
        candidates.extend(get_env_paths("LD_LIBRARY_PATH", ":"))
    elif platform.startswith("darwin"):
        candidates.extend(get_env_paths("DYLD_LIBRARY_PATH", ":"))
    elif platform.startswith("win32"):
        candidates.extend(get_env_paths("PATH", ";"))
    return [os.path.abspath(path) for path in candidates if os.path.isdir(path)]
def find_lib_path(name, optional=False):
    """Find tile lang library

    Parameters
    ----------
    name : str
        The name of the library
    optional: boolean
        Whether the library is required

    Raises
    ------
    RuntimeError
        When the library cannot be found and `optional` is False.
    """
    # Map the platform to its shared-library filename convention;
    # anything unrecognized falls back to the ELF ".so" style.
    platform = sys.platform
    if platform.startswith("win32"):
        lib_name = f"{name}.dll"
    elif platform.startswith("darwin"):
        lib_name = f"lib{name}.dylib"
    else:
        lib_name = f"lib{name}.so"
    candidate_files = [os.path.join(d, lib_name) for d in get_dll_directories()]
    lib_found = [p for p in candidate_files if os.path.isfile(p)]
    if not lib_found and not optional:
        message = (f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" +
                   "\n".join(candidate_files))
        raise RuntimeError(message)
    return lib_found
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
""" bootstrap the primitives module via tile language """
from .gemm import gemm # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from tvm import tir
from tilelang.primitives.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.primitives.gemm.gemm_mma import (
GemmPrimitiveMMA,
)
def gemm(
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    transpose_A: bool = False,
    transpose_B: bool = False,
    block_row_warps: Optional[int] = None,
    block_col_warps: Optional[int] = None,
    warp_row_tiles: Optional[int] = None,
    warp_col_tiles: Optional[int] = None,
    chunk: Optional[int] = None,
    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
    k_pack: int = 1,
):
    """Validate operand scopes and dispatch to the MMA GEMM primitive.

    A and B may live in local, fragment or shared scope; C must be local or
    fragment. All tiling parameters are optional and, when omitted, inferred
    by the primitive itself.
    """

    def _operand_ok(buf):
        # A/B may live in local, fragment, or shared scope.
        return is_local(buf) or is_fragment(buf) or is_shared(buf)

    assert _operand_ok(A), (
        f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}"
    )
    assert _operand_ok(B), (
        f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}"
    )
    assert is_local(C) or is_fragment(C), (
        f"Expected C to be a local, fragment, but got {C.scope()}"
    )
    # TODO(lei): Now we only support Nvidia GPUs
    # Must enhance the design to implement runtime lowering
    # for different targets (hip mfma for example)
    primitive = GemmPrimitiveMMA(
        A=A,
        B=B,
        C=C,
        transpose_A=transpose_A,
        transpose_B=transpose_B,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        policy=policy,
        k_pack=k_pack,
    )
    return primitive.invoke()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from enum import IntEnum
from dataclasses import dataclass
from typing import Optional
from tvm import tir
class GemmWarpPolicy(IntEnum):
    """
    Enumeration for GEMM Warp Partitioning Policies.
    """

    Square = 0  # Balance warps evenly in a "square" aspect ratio.
    FullRow = 1  # Assign all warps to rows.
    FullCol = 2  # Assign all warps to columns.

    def is_square(self) -> bool:
        """True when this policy balances warps across rows and columns."""
        return self == GemmWarpPolicy.Square

    def is_full_row(self) -> bool:
        """True when this policy assigns every warp to rows."""
        return self == GemmWarpPolicy.FullRow

    def is_full_col(self) -> bool:
        """True when this policy assigns every warp to columns."""
        return self == GemmWarpPolicy.FullCol

    @staticmethod
    def to_prime_factors(num):
        """Return the prime factorization of *num* (ascending, with multiplicity)."""
        remaining = num
        candidate = 2
        factors = []
        # Trial division up to the square root of the remaining value.
        while candidate * candidate <= remaining:
            while remaining % candidate == 0:
                factors.append(candidate)
                remaining //= candidate
            candidate += 1
        # Whatever is left above 1 is itself prime.
        if remaining > 1:
            factors.append(remaining)
        return factors

    def compute_warp_partition(self, M, N, num_warps):
        """
        Split *num_warps* into (m_warp, n_warp) according to this policy.

        FullRow/FullCol devote every warp to one axis (asserting divisibility);
        Square greedily distributes the prime factors of num_warps between the
        two axes, preferring the axis with the larger per-warp extent.

        Args:
            M (int): The number of rows in the GEMM workload.
            N (int): The number of columns in the GEMM workload.
            num_warps (int): The total number of warps available.

        Returns:
            tuple: (m_warp, n_warp), the partitioning of warps.

        Raises:
            ValueError: If the policy is unknown or no divisibility condition
                can be satisfied for the Square policy.
            AssertionError: If M or N is not divisible by num_warps for the
                FullRow or FullCol policies.
        """
        m_warp, n_warp = 1, 1
        if self.is_full_row():
            m_warp = num_warps
            assert (
                M % num_warps == 0
            ), "M must be divisible by num_warps for FullRow policy"
        elif self.is_full_col():
            n_warp = num_warps
            assert (
                N % num_warps == 0
            ), "N must be divisible by num_warps for FullCol policy"
        elif self.is_square():
            for factor in self.to_prime_factors(num_warps):
                m_ok = (M % (factor * m_warp)) == 0
                n_ok = (N % (factor * n_warp)) == 0
                if m_ok and n_ok:
                    # Both axes divisible: grow whichever currently has the
                    # larger per-warp extent.
                    if M / m_warp >= N / n_warp:
                        m_warp *= factor
                    else:
                        n_warp *= factor
                elif m_ok:
                    m_warp *= factor
                elif n_ok:
                    n_warp *= factor
                else:
                    raise ValueError(
                        f"Cannot compute warp partition for shape {M} x {N} with num_warps {num_warps}"
                    )
        else:
            raise ValueError(f"Unknown GemmWarpPolicy: {self}")
        return m_warp, n_warp
@dataclass
class GemmBaseParams:
    """Shared parameter bundle for GEMM primitives.

    Holds the operand buffers, transpose flags and (optionally) the
    block/warp partitioning configuration; `infer_block_partition` fills in
    any missing partition parameters from the operand shapes and the launch
    thread count.
    """

    # OP Related Config
    A: tir.Buffer
    B: tir.Buffer
    C: tir.Buffer
    transpose_A: bool = False
    transpose_B: bool = False
    block_row_warps: Optional[int] = None
    block_col_warps: Optional[int] = None
    warp_row_tiles: Optional[int] = None
    warp_col_tiles: Optional[int] = None
    chunk: Optional[int] = None
    # Fix: the original line ended with a stray trailing comma, which made the
    # default a one-element tuple `(GemmWarpPolicy.Square,)` instead of the
    # enum value, breaking any later `policy.compute_warp_partition(...)` call.
    policy: GemmWarpPolicy = GemmWarpPolicy.Square
    k_pack: int = 1

    def get_warp_size(self) -> int:
        """Return the number of threads per warp (32)."""
        # must rewrite to 64 if the target
        # is cdna mfma
        return 32

    def params_as_dict(self):
        """Return all GEMM parameters as a plain dict."""
        return {
            "A": self.A,
            "B": self.B,
            "C": self.C,
            "transpose_A": self.transpose_A,
            "transpose_B": self.transpose_B,
            "block_row_warps": self.block_row_warps,
            "block_col_warps": self.block_col_warps,
            "warp_row_tiles": self.warp_row_tiles,
            "warp_col_tiles": self.warp_col_tiles,
            "chunk": self.chunk,
            "policy": self.policy,
            "k_pack": self.k_pack,
        }

    def infer_block_partition(self, threads: Optional[int]) -> None:
        """
        Infer and set block partition parameters (block_row_warps,
        block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the
        shapes of A and B. If these parameters are not already specified, the
        method will attempt to infer them automatically based on the given
        `threads`.

        Parameters
        ----------
        threads : Optional[int]
            The total number of threads in a block. Must be provided
            if any block partition parameter is not already set.

        Raises
        ------
        AssertionError
            If `threads` is None but any block partition parameter is missing,
            or if A and B have inconsistent shapes for GEMM.
        """
        warp_size = self.get_warp_size()
        A, B = self.A, self.B
        transpose_A, transpose_B = self.transpose_A, self.transpose_B
        block_row_warps, block_col_warps = (
            self.block_row_warps,
            self.block_col_warps,
        )
        warp_row_tiles, warp_col_tiles = (
            self.warp_row_tiles,
            self.warp_col_tiles,
        )
        policy = self.policy
        # `chunk` may not be set yet; read it defensively and infer it from
        # the K dimension when missing.
        chunk = getattr(self, "chunk", None)
        # Inference is needed if any partition parameter is missing.
        require_infer = (
            block_row_warps is None
            or block_col_warps is None
            or warp_row_tiles is None
            or warp_col_tiles is None
            or chunk is None
        )
        A_shape, B_shape = A.shape, B.shape
        if require_infer:
            assert (
                threads is not None
            ), "threads must be provided for auto inference"
            # Auto-inference only supports 2D matrix multiplication
            assert (
                len(A_shape) == 2 and len(B_shape) == 2
            ), f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D"
            # Resolve logical M/N/K from the (possibly transposed) shapes.
            AM = A_shape[1] if transpose_A else A_shape[0]  # M dimension
            BN = B_shape[0] if transpose_B else B_shape[1]  # N dimension
            AK = A_shape[0] if transpose_A else A_shape[1]  # K dimension
            BK = B_shape[1] if transpose_B else B_shape[0]  # K dimension
            assert AK == BK, "A and B shape mismatch"
            block_M = int(AM)
            block_N = int(BN)
            num_warps = threads // warp_size
            # Partition the warps across M/N using the configured policy.
            block_row_warps, block_col_warps = policy.compute_warp_partition(
                block_M, block_N, num_warps
            )
            warp_row_tiles = block_M // block_row_warps
            warp_col_tiles = block_N // block_col_warps
            chunk = int(AK)
        # Persist the (possibly inferred) values back onto the instance.
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk

    @property
    def class_attributes(self):
        """Alias of params_as_dict(), used by __repr__."""
        return self.params_as_dict()

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        fields = self.class_attributes
        field_str = ", ".join(
            f"{key}={value!r}" for key, value in fields.items()
        )
        return f"{cls_name}({field_str})"
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Optional, Dict
from dataclasses import dataclass
from tvm import tir
import tilelang.language as T
from tilelang.primitives.utils import is_fragment, array_reduce
from tilelang.primitives.gemm.base import GemmBaseParams
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
# TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
@dataclass
class GemmPrimitiveMMA(GemmBaseParams):
"""
A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix
Multiply and Accumulate) instructions. Inherits from GemmBaseParams which
provides basic parameters such as A, B, C buffers and transposition flags.
"""
def gemm_rrr(
    self,
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
    """Register-register-register GEMM: both A and B are fragments.

    Not implemented yet (see the GEMM_SR/RS/RR TODO at the top of this file).
    """
    raise NotImplementedError("GEMM_RRR is not implemented yet")
def gemm_rsr(
    self,
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    mma_emitter: TensorCoreIntrinEmitter,
)-> tir.PrimExpr:
    """Register-shared-register GEMM: A already lives in a fragment, B is
    loaded from shared memory, and results accumulate into C.

    Parameters
    ----------
    A : tir.Buffer
        Operand A; dispatched here when A is a fragment (see invoke()).
    B : tir.Buffer
        Operand B, loaded via ldmatrix from shared memory.
    C : tir.Buffer
        Accumulation buffer.
    mma_emitter : TensorCoreIntrinEmitter
        Helper emitting Tensor Core ldmatrix/mma instructions.

    Returns
    -------
    tir.PrimExpr
        The generated IR expression (macro call) for the GEMM loop.
    """
    in_dtype = self.in_dtype
    # NOTE(review): warp_rows and local_size_a are read but never used below
    # (A is already a fragment, so no A-side local allocation happens here).
    warp_rows = mma_emitter.warp_rows
    warp_cols = mma_emitter.warp_cols
    local_size_a = mma_emitter.local_size_a
    local_size_b = mma_emitter.local_size_b
    block_K = mma_emitter.chunk
    micro_size_k = mma_emitter.micro_size_k
    threads = mma_emitter.threads
    # Check if A/C are fragments for applying custom layouts below.
    a_is_fragment = is_fragment(A)
    c_is_fragment = is_fragment(C)

    @T.macro
    def _gemm_rsr(
        A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
    ) -> None:
        """
        The inner macro that loads data from shared buffer B_shared into a
        local fragment, then issues Tensor Core mma ops against the already
        local A_local, accumulating into C_local.
        """
        # Only B needs a local fragment; A is already local.
        B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
        if a_is_fragment:
            # Annotate layout for A_local if it is a fragment.
            T.annotate_layout(
                {
                    A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
                }
            )
        if c_is_fragment:
            # Annotate layout for C_local if it is a fragment.
            T.annotate_layout(
                {
                    C_local: mma_emitter.make_mma_store_layout(C_local),
                }
            )
        for ki in T.serial(0, (block_K // micro_size_k)):
            # Load B into fragment
            mma_emitter.ldmatrix_b(
                B_local,
                B_shared,
                ki,
                thread_bindings=thread_bindings,
            )
            # Perform Matrix Multiplication
            mma_emitter.mma(
                A_local,
                B_local,
                C_local,
                ki,
            )

    return _gemm_rsr(A, B, C)
def gemm_srr(
    self,
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
    """Shared-register-register GEMM: dispatched when B is a fragment and A
    is not (see invoke()).

    Not implemented yet.
    """
    # Fix: the original message said "GEMM_RSR" (copy-pasted from gemm_rsr);
    # this method is the SRR variant.
    raise NotImplementedError("GEMM_SRR is not implemented yet")
def gemm_ssr(
    self,
    A: tir.Buffer,
    B: tir.Buffer,
    C: tir.Buffer,
    mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
    """
    Perform a single-step reduction (SSR) GEMM using Tensor Core MMA
    primitives. Loads fragments of A and B from shared memory, multiplies
    them, and accumulates into C.

    Parameters
    ----------
    A : tir.Buffer
        The buffer for matrix A (in shared memory).
    B : tir.Buffer
        The buffer for matrix B (in shared memory).
    C : tir.Buffer
        The buffer for the accumulation results.
    mma_emitter : TensorCoreIntrinEmitter
        A helper object responsible for generating Tensor Core MMA
        instructions (ldmatrix, mma, etc.).

    Returns
    -------
    tir.PrimExpr
        The generated IR expression (macro) representing the GEMM loop.
    """
    in_dtype = self.in_dtype
    # Tile/fragment geometry pulled from the emitter configuration.
    warp_rows = mma_emitter.warp_rows
    warp_cols = mma_emitter.warp_cols
    local_size_a = mma_emitter.local_size_a
    local_size_b = mma_emitter.local_size_b
    block_K = mma_emitter.chunk
    micro_size_k = mma_emitter.micro_size_k
    threads = mma_emitter.threads
    # Check if C is a fragment for applying custom layout
    c_is_fragment = is_fragment(C)

    @T.macro
    def _gemm_ssr(
        A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
    ) -> None:
        """
        The inner macro that loads data from shared buffers A_shared and
        B_shared into local fragments, then issues Tensor Core mma ops,
        accumulating into C_local.
        """
        A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
        B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
        if c_is_fragment:
            # Annotate layout for C_local if it is a fragment.
            T.annotate_layout(
                {
                    C_local: mma_emitter.make_mma_store_layout(
                        C_local
                    ),
                }
            )
        for ki in T.serial(0, (block_K // micro_size_k)):
            # Load A into fragment
            mma_emitter.ldmatrix_a(
                A_local,
                A_shared,
                ki,
                thread_bindings=thread_bindings,
            )
            # Load B into fragment
            mma_emitter.ldmatrix_b(
                B_local,
                B_shared,
                ki,
                thread_bindings=thread_bindings,
            )
            # Perform Matrix Multiplication
            # NOTE(review): unlike gemm_rsr, `ki` is not passed to mma() here —
            # confirm whether that difference is intentional.
            mma_emitter.mma(A_local, B_local, C_local)

    return _gemm_ssr(A, B, C)
def invoke(self) -> tir.PrimExpr:
"""
Entry point to generate a GEMM SSR (single-step reduction) with Tensor
Core instructions. Performs the following steps:
1. Infers block partition parameters if necessary.
2. Creates a `TensorCoreIntrinEmitter` with the correct data types
and dimensions.
3. Invokes the GEMM SSR function to generate the final IR expression.
Returns
-------
tir.PrimExpr
The generated GEMM IR expression.
"""
# Infer block partition if necessary
current_frame = T.kernel.KernelLaunchFrame.Current()
threads = current_frame.num_threads
self.infer_block_partition(threads)
A, B, C = self.A, self.B, self.C
transpose_A, transpose_B = self.transpose_A, self.transpose_B
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
warp_row_tiles, warp_col_tiles = (
self.warp_row_tiles,
self.warp_col_tiles,
)
chunk = self.chunk
# Check dtypes
A_dtype, B_dtype, C_dtype = A.dtype, B.dtype, C.dtype
assert A_dtype == B_dtype, "A and B must have the same dtype"
in_dtype, accum_dtype = A_dtype, C_dtype
# Create the MMA emitter
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=transpose_A,
b_transposed=transpose_B,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
a_is_fragment = is_fragment(A)
b_is_fragment = is_fragment(B)
if a_is_fragment and b_is_fragment:
return self.gemm_rrr(A, B, C, mma_emitter)
if a_is_fragment:
return self.gemm_rsr(A, B, C, mma_emitter)
if b_is_fragment:
return self.gemm_srr(A, B, C, mma_emitter)
return self.gemm_ssr(A, B, C, mma_emitter)
@property
def in_dtype(self) -> str:
"""
Returns
-------
str
The input data type for A and B. Assumes both have the same dtype.
Raises
------
AssertionError
If A and B do not share the same dtype.
"""
A_dtype, B_dtype = self.A.dtype, self.B.dtype
assert A_dtype == B_dtype, "A and B must have the same dtype"
return self.A.dtype
@property
def accum_dtype(self) -> str:
"""
Returns
-------
str
The accumulation data type for C.
"""
return self.C.dtype
__all__ = ["GemmPrimitiveMMA"]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm.tir import Buffer
from typing import List
from functools import reduce
# Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer.
def is_global(buffer: Buffer) -> bool:
    """
    Report whether the buffer lives in global memory.

    Args:
        buffer (Buffer): The TVM buffer to check.

    Returns:
        bool: True if the buffer's scope is "global", False otherwise.
    """
    scope = buffer.scope()
    return scope == "global"
def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
    """
    Check if the buffer is in the shared memory scope.

    Args:
        buffer (Buffer): The TVM buffer to check.
        allow_dynamic (bool): If True (default), buffers in the dynamic
            shared memory scope ("shared.dyn") also count as shared.

    Returns:
        bool: True if the buffer is in shared memory, False otherwise.
    """
    # Direct boolean logic instead of accumulating a `conditions` list
    # seeded with a meaningless False entry.
    if buffer.scope() == "shared":
        return True
    return bool(allow_dynamic and is_shared_dynamic(buffer))
def is_shared_dynamic(buffer: Buffer) -> bool:
    """
    Report whether the buffer lives in dynamic shared memory.

    Args:
        buffer (Buffer): The TVM buffer to check.

    Returns:
        bool: True if the buffer's scope is "shared.dyn", False otherwise.
    """
    scope = buffer.scope()
    return scope == "shared.dyn"
def is_local(buffer: Buffer) -> bool:
    """
    Report whether the buffer lives in local memory.

    Args:
        buffer (Buffer): The TVM buffer to check.

    Returns:
        bool: True if the buffer's scope is "local", False otherwise.
    """
    scope = buffer.scope()
    return scope == "local"
def is_fragment(buffer: Buffer) -> bool:
    """
    Report whether the buffer is a register fragment (e.g. an MMA operand).

    Args:
        buffer (Buffer): The TVM buffer to check.

    Returns:
        bool: True if the buffer's scope starts with "local.fragment",
        False otherwise.
    """
    scope = buffer.scope()
    return scope.startswith("local.fragment")
def array_reduce(array: List[int]) -> int:
    """
    Reduce an array of integers to their product.

    Args:
        array (List[int]): The array of integers to reduce.

    Returns:
        int: The product of all elements; 1 for an empty array
        (the multiplicative identity).
    """
    # Initializer `1` makes the empty list well-defined; previously an
    # empty input raised TypeError from reduce().
    return reduce(lambda acc, value: acc * value, array, 1)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import inspect
import pytest
from tvm.testing.utils import *
# pytest.main() wrapper to allow running single test file
def main():
    """Run pytest on the calling test file, forwarding CLI arguments.

    Allows a single test file to end with ``if __name__ == "__main__": main()``
    and be executed directly with ``python test_foo.py [pytest args]``.
    """
    # _getframe(1) resolves the CALLER's frame; do not move this lookup
    # into a helper function, as that would change the frame depth.
    test_file = inspect.getsourcefile(sys._getframe(1))
    # Exit with pytest's return code so CI sees failures.
    sys.exit(pytest.main([test_file] + sys.argv[1:]))
def torch_assert_close(tensor_a,
                       tensor_b,
                       rtol=1e-2,
                       atol=1e-3,
                       max_mismatched_ratio=0.001,
                       verbose=False):
    """
    Assert that two tensors are "close enough", tolerating a bounded
    fraction of mismatched elements (unlike an exact element-wise check).

    Parameters
    ----------
    tensor_a : torch.Tensor
        The first tensor to compare.
    tensor_b : torch.Tensor
        The second (reference) tensor; the relative tolerance is scaled
        by its magnitude.
    rtol : float, optional
        Relative tolerance for comparison. Default is 1e-2.
    atol : float, optional
        Absolute tolerance for comparison. Default is 1e-3.
    max_mismatched_ratio : float, optional
        Maximum ratio of mismatched elements allowed (relative to the
        total number of elements). Default is 0.001 (0.1%).
    verbose : bool, optional
        If True, print mismatch statistics even when the check passes.

    Returns
    -------
    bool
        True if the tensors are close enough.

    Raises
    ------
    AssertionError
        If the ratio of mismatched elements exceeds `max_mismatched_ratio`.
    """
    import torch

    # An element mismatches when |a - b| > atol + rtol * |b|
    # (the same criterion torch.isclose uses, without NaN handling).
    diff = torch.abs(tensor_a - tensor_b)
    max_diff = atol + rtol * torch.abs(tensor_b)
    mismatched = diff > max_diff

    num_mismatched = mismatched.sum().item()
    total_elements = tensor_a.numel()
    # Mismatch budget derived from the requested ratio.
    max_allowed_mismatched = int(total_elements * max_mismatched_ratio)

    if verbose:
        print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
              f"(allowed: {max_allowed_mismatched})")

    if num_mismatched > max_allowed_mismatched:
        raise AssertionError(
            f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). "
            f"Greatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    return True
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapping transformations."""
# pylint: disable=invalid-name, unsupported-binary-operation
from . import _ffi_api
from .simplify import Simplify, simplify_prim_func # noqa: F401
def ClusterPlanning():
    """Create the ClusterPlanning pass.

    Thin wrapper over the C++ `tl.transform.ClusterPlanning` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.ClusterPlanning()  # type: ignore
def PipelinePlanning():
    """Create the PipelinePlanning pass.

    Thin wrapper over the C++ `tl.transform.PipelinePlanning` pass
    dispatched through the FFI.

    NOTE(review): the original summary here ("infer the fragment/shared
    memory layout") looks copy-pasted from LayoutInference — confirm the
    intended description against the C++ pass.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.PipelinePlanning()  # type: ignore
def LayoutInference():
    """Create the LayoutInference pass.

    Thin wrapper over the C++ `tl.transform.LayoutInference` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LayoutInference()  # type: ignore
def LowerTileOp():
    """Create the LowerTileOp pass.

    Thin wrapper over the C++ `tl.transform.LowerTileOp` pass dispatched
    through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LowerTileOp()  # type: ignore
def InjectSoftwarePipeline():
    """Create the InjectSoftwarePipeline pass.

    Thin wrapper over the C++ `tl.transform.InjectSoftwarePipeline` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectSoftwarePipeline()  # type: ignore
def FrontendLegalize():
    """Create the FrontendLegalize pass.

    Thin wrapper over the C++ `tl.transform.FrontendLegalize` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.FrontendLegalize()  # type: ignore
def LowerHopperIntrin():
    """Create the LowerHopperIntrin pass.

    Thin wrapper over the C++ `tl.transform.LowerHopperIntrin` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LowerHopperIntrin()  # type: ignore
def WarpSpecializedPipeline():
    """Create the WarpSpecializedPipeline pass.

    Thin wrapper over the C++ `tl.transform.WarpSpecializedPipeline` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.WarpSpecializedPipeline()  # type: ignore
def ThreadPartialSync(storage_scope: str):
    """Create the ThreadPartialSync pass, which inserts partial syncs.

    Thin wrapper over the C++ `tl.transform.ThreadPartialSync` pass
    dispatched through the FFI.

    Parameters
    ----------
    storage_scope: str
        The target storage scope.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.ThreadPartialSync(storage_scope)  # type: ignore
def MultiVersionBuffer():
    """Create the MultiVersionBuffer pass.

    Thin wrapper over the C++ `tl.transform.MultiVersionBuffer` pass
    dispatched through the FFI. (Original docstring said
    "WarpSpecializedPipeline" — a copy-paste error.)

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.MultiVersionBuffer()  # type: ignore
def WarpSpecialized():
    """Create the WarpSpecialized pass.

    Thin wrapper over the C++ `tl.transform.WarpSpecialized` pass
    dispatched through the FFI. (Original docstring said
    "WarpSpecializedPipeline" — a copy-paste error.)

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.WarpSpecialized()  # type: ignore
def InjectFenceProxy():
    """Create the InjectFenceProxy pass.

    Thin wrapper over the C++ `tl.transform.InjectFenceProxy` pass
    dispatched through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectFenceProxy()  # type: ignore
def LegalizeVectorizedLoop():
    """Create the LegalizeVectorizedLoop pass.

    Thin wrapper over the C++ `tl.transform.LegalizeVectorizedLoop` pass
    dispatched through the FFI. (Original docstring said
    "LegalizeLoopVectorize" — a naming mismatch.)

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LegalizeVectorizedLoop()  # type: ignore
def LegalizeSafeMemoryAccess():
    """Create the LegalizeSafeMemoryAccess pass.

    Thin wrapper over the C++ `tl.transform.LegalizeSafeMemoryAccess`
    pass dispatched through the FFI. (Original docstring said
    "LegalizeLoopVectorize" — a copy-paste error.)

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LegalizeSafeMemoryAccess()  # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""FFI APIs for tilelang"""
import tvm._ffi
# Register every C++ global function exposed under the "tl.transform."
# prefix (i.e. TVM_REGISTER_GLOBAL("tl.transform.name").set_body_typed(func))
# as an attribute of this module.
tvm._ffi._init_api("tl.transform", __name__)  # pylint: disable=protected-access
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Union, Callable
from . import _ffi_api
def Simplify():
    """Create the Simplify pass.

    Thin wrapper over the C++ `tl.transform.Simplify` pass dispatched
    through the FFI.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.Simplify()  # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
    """Apply the Simplify pass to a PrimFunc or IRModule.

    A PrimFunc is wrapped in a temporary IRModule, simplified, and the
    single resulting function is returned; an IRModule is simplified
    directly. Any other input raises ValueError.
    """
    if isinstance(stmt, IRModule):
        return Simplify()(stmt)
    if isinstance(stmt, PrimFunc):
        simplified = Simplify()(IRModule.from_expr(stmt))
        assert len(simplified.functions) == 1, "Simplify should return a single function"
        return next(iter(simplified.functions.values()))
    raise ValueError(f"Unsupported type: {type(stmt)}")
# Decorator to simplify the output of a function
def simplify_prim_func(func: Callable) -> Callable:
    """Decorator that simplifies the PrimFunc/IRModule returned by `func`.

    The wrapped callable is invoked as-is and its result is passed through
    the Simplify pass via `_Simplify`.

    Parameters
    ----------
    func : Callable
        A callable returning a PrimFunc or IRModule.

    Returns
    -------
    Callable
        The wrapping callable with the same signature as `func`.
    """
    # Function-scope import to avoid touching the module's import block.
    from functools import wraps

    # wraps() preserves func's __name__/__doc__ on the wrapper, which the
    # original implementation lost.
    @wraps(func)
    def wrapper(*args, **kwargs):
        stmt: Union[PrimFunc, IRModule] = func(*args, **kwargs)
        return _Simplify(stmt)

    return wrapper
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from .target import determine_target # noqa: F401
from .profiler import Profiler # noqa: F401
from .tensor import TensorSupplyType, torch_assert_close # noqa: F401
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment