"include/ck/utility/math.hpp" did not exist on "bbcb67d0aac81b51336981713662a726875ebd58"
Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
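The hunks below replace yapf-wrapped call sites with single-line ruff formatting across the Python sources. As a rough illustration of the new workflow only (not taken from this commit; the ruff_format helper name, the 120-character line length, and the default paths are assumptions for the sketch), the ruff formatter can be driven from a small Python wrapper such as a CI or pre-commit helper:

import subprocess
import sys


def ruff_format(paths: list[str], line_length: int = 120, check_only: bool = False) -> int:
    """Run `ruff format` over the given paths and return its exit code.

    Assumes the `ruff` binary is installed and on PATH; the real project
    configuration may pin a different line length in pyproject.toml.
    """
    cmd = ["ruff", "format", f"--line-length={line_length}", *paths]
    if check_only:
        # --check reports files that would be reformatted without rewriting them.
        cmd.append("--check")
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.stdout:
        print(proc.stdout, end="")
    if proc.returncode != 0 and proc.stderr:
        print(proc.stderr, end="", file=sys.stderr)
    return proc.returncode


if __name__ == "__main__":
    # e.g. `python <this helper> tilelang/` verifies formatting without modifying files.
    raise SystemExit(ruff_format(sys.argv[1:] or ["."], check_only=True))

Running the helper with check_only=False applies the same single-line style seen in the hunks that follow.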
......@@ -16,12 +16,7 @@ from tvm.base import py_str
from tvm.contrib.rocm import get_rocm_arch, find_rocm_path
def compile_hip(code,
target_format="hsaco",
arch=None,
options=None,
path_target=None,
verbose=False):
def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False):
"""Compile HIP code with hipcc.
Parameters
......@@ -61,7 +56,7 @@ def compile_hip(code,
file_target = path_target if path_target else temp_target
cmd = ["hipcc"]
cmd += ["-O3", '-c']
cmd += ["-O3", "-c"]
if isinstance(arch, str):
cmd += [f"--offload-arch={arch}"]
if target_format == "hsaco":
......
# pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from __future__ import annotations
import os
......@@ -18,12 +19,7 @@ from tvm.base import py_str
from tvm.contrib import utils
def compile_cuda(code,
target_format="ptx",
arch=None,
options=None,
path_target=None,
verbose=False):
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False):
"""Compile cuda code with NVCC from env.
Parameters
......@@ -67,7 +63,7 @@ def compile_cuda(code,
temp_target = temp.relpath(f"{file_name}.{target_format}")
pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None))
kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None)
if kernels_output_dir is not None:
if not os.path.isdir(kernels_output_dir):
os.makedirs(kernels_output_dir)
......@@ -114,10 +110,7 @@ def compile_cuda(code,
print(py_str(out))
if proc.returncode != 0:
msg = f"{code}\n" \
f"Compilation error:\n" \
f"{py_str(out)}\n" \
f"Command: {' '.join(cmd)}\n"
msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n"
raise RuntimeError(msg)
with open(file_target, "rb") as f:
......@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if compile_flags:
import shlex
for flag in compile_flags:
# Split each string like a shell would, preserving quoted args
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
......@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
return options
def get_ptx_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
"""
Compile CUDA C++ source to PTX using NVCC and return as text.
......@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None:
return None
def get_sass_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
"""
Compile CUDA C++ source to CUBIN and disassemble to SASS.
......@@ -246,9 +236,7 @@ def get_sass_from_source(code: str,
cand_nvdisasm = _find_tool("nvdisasm")
cand_cuobjdump = _find_tool("cuobjdump")
if not cand_nvdisasm and not cand_cuobjdump:
raise RuntimeError(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
raise RuntimeError("Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.")
last_err: str | None = None
try:
# Attempt nvdisasm first
......@@ -268,8 +256,7 @@ def get_sass_from_source(code: str,
return text
last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
# If we reach here, all attempts failed
raise RuntimeError(f"SASS disassembly failed. Tried tools: "
f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
raise RuntimeError(f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
finally:
with contextlib.suppress(Exception):
os.remove(cubin_path)
......@@ -438,8 +425,7 @@ def get_target_compute_version(target=None):
if tvm.cuda(0).exist:
return tvm.cuda(0).compute_version
raise ValueError("No CUDA architecture was specified or GPU detected."
"Try specifying it by adding '-arch=sm_xx' to your target.")
raise ValueError("No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version) -> tuple[int, int]:
......@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None):
warnings.warn(
"Tensorcore will be disabled due to no CUDA architecture specified."
"Try specifying it by adding '-arch=sm_xx' to your target.",
stacklevel=2)
stacklevel=2,
)
return False
compute_version = target.attrs["arch"]
# Compute version will be in the form "sm_{major}{minor}"
......
......@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]:
return (major, minor)
def compile_cuda(code: str,
def compile_cuda(
code: str,
target_format: Literal["ptx", "cubin"] = "ptx",
arch: int | None = None,
options: str | list[str] | None = None,
verbose: bool = False) -> bytearray:
verbose: bool = False,
) -> bytearray:
"""Compile cuda code with NVRTC.
Parameters
......@@ -43,8 +45,7 @@ def compile_cuda(code: str,
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc.
major, minor = parse_compute_version(
get_target_compute_version(Target.current(allow_none=True)))
major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
arch = major * 10 + minor
prefix = "compute" if target_format == "ptx" else "sm"
suffix = "a" if arch >= 90 else ""
......@@ -77,8 +78,7 @@ def compile_cuda(code: str,
compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0]
if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
msg = f"{code}\n" \
f"Compilation error:\n"
msg = f"{code}\nCompilation error:\n"
if verbose:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}"
......@@ -105,7 +105,6 @@ def compile_cuda(code: str,
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}"
# Destroy handler
assert nvrtc.nvrtcDestroyProgram(
program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
return result_bytes
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Utility for ROCm backend"""
# ruff: noqa
import re
import subprocess
......@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"):
gpu_arch = match.group(1)
return gpu_arch
except subprocess.CalledProcessError:
print(f"Unable to execute rocminfo command, \
print(
f"Unable to execute rocminfo command, \
please ensure ROCm is installed and you have an AMD GPU on your system.\
using default {gpu_arch}.")
using default {gpu_arch}."
)
return gpu_arch
......
"""The compiler for TL programs."""
from __future__ import annotations
import os
......@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target):
def has_device_kernel_launch(attrs) -> bool:
"""Check if the attributes indicate a device kernel launch."""
return bool(attrs and "calling_conv" in attrs and
attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
def is_device_call_c_device(func: tir.PrimFunc):
attrs = func.attrs
calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT)
is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC)
is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC
# Check if it's a C target
if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked:
......@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
if var in func.buffer_map:
tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
else:
if var.dtype == 'handle':
if var.dtype == "handle":
raise ValueError(
f'Handle parameter {var} must be mapped to a buffer.\n'
f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.')
f"Handle parameter {var} must be mapped to a buffer.\n"
f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer."
)
tensor_types.append(KernelParam.from_var(var))
return tensor_types
def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
......@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(
device_mod, target)
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
elif target.kind.name == "c":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
elif target.kind.name == "llvm":
......@@ -222,12 +220,12 @@ def lower(
enable_host_codegen=False,
enable_device_compile=False,
) -> CompiledArtifact:
'''
"""
enable_host_codegen: whether to enable host codegen, default is False, as we have our
own host codegen implementation in jit.
enable_device_compile: whether to enable device codegen, default is False, as we have our
own device codegen implementation in jit.
'''
"""
mod = func_or_mod
params = None
......@@ -259,14 +257,11 @@ def lower(
host_mod = tir.transform.Filter(_is_host_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod)
codegen_mod = device_codegen(
device_mod, target) if enable_device_compile else device_codegen_without_compile(
device_mod, target)
codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target)
if enable_host_codegen:
host_mod = host_codegen(host_mod, target_host)
host_mod.import_module(codegen_mod)
return CompiledArtifact(
host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
"""The profiler and convert to torch utils"""
from __future__ import annotations
from dataclasses import dataclass
......@@ -14,6 +15,7 @@ class KernelParam:
Represents parameters for a kernel operation, storing dtype and shape information.
Used to describe tensor or scalar parameters in TVM/PyTorch interop.
"""
dtype: torch.dtype # PyTorch data type of the parameter
shape: list[int | Var] # List of dimensions, can be integers or TVM variables
......@@ -109,6 +111,7 @@ class CompiledArtifact:
Represents a compiled kernel artifact containing both host and device code.
Stores all necessary components for kernel execution in the TVM runtime.
"""
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel
......
......@@ -6,8 +6,7 @@ from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper
def allow_warp_specialized(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
# avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
......@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None,
return not disable_warp_specialized
def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target):
......@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) ->
return enable_global_thread_sync
def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool(
pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
# when warp specialization is enabled, as different warp threads may access different
......@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
return ["txt", "png", "pdf", "svg"]
if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
formats_list = [f.strip() for f in formats_str.split(",")]
else:
formats_list = [formats_str]
......@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge)(
mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass
......
......@@ -10,36 +10,34 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
CUTLASS_NOT_FOUND_MESSAGE = "CUTLASS is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
"Composable Kernel is not installed or found in the expected path")
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = "Composable Kernel is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
TL_TEMPLATE_NOT_FOUND_MESSAGE = "TileLang is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"
TL_ROOT = os.path.dirname(os.path.abspath(__file__))
# Only expose the internal lib directory to sys.path to avoid shadowing
# common top-level module names (e.g., utils, analysis) from user projects.
TL_LIBS = [os.path.join(TL_ROOT, 'lib')]
TL_LIBS = [os.path.join(TL_ROOT, "lib")]
TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)]
DEV = False
THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty')
THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty")
if not os.path.exists(THIRD_PARTY_ROOT):
DEV = True
tl_dev_root = os.path.dirname(TL_ROOT)
dev_lib_root = os.path.join(tl_dev_root, 'build')
dev_lib_root = os.path.join(tl_dev_root, "build")
# In dev builds, place artifacts under build/lib and point search path there
# to avoid adding the entire build root to sys.path.
TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')]
THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty')
logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}')
TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")]
THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty")
logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}")
assert TL_LIBS and all(
os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}'
assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}"
for lib in TL_LIBS:
if lib not in sys.path:
......@@ -52,7 +50,7 @@ def _find_cuda_home() -> str:
Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
"""
# Guess #1
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
......@@ -70,15 +68,15 @@ def _find_cuda_home() -> str:
else:
# Guess #3
if sys.platform == 'win32':
cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0]
if sys.platform == "win32":
cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0]
else:
# Linux/macOS
if os.path.exists('/usr/local/cuda'):
cuda_home = '/usr/local/cuda'
elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'):
cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64'
if os.path.exists("/usr/local/cuda"):
cuda_home = "/usr/local/cuda"
elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"):
cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64"
# Validate found path
if cuda_home is None or not os.path.exists(cuda_home):
......@@ -89,13 +87,13 @@ def _find_cuda_home() -> str:
def _find_rocm_home() -> str:
"""Find the ROCM install path."""
rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME')
rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME")
if rocm_home is None:
rocmcc_path = shutil.which("hipcc")
if rocmcc_path is not None:
rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
else:
rocm_home = '/opt/rocm'
rocm_home = "/opt/rocm"
if not os.path.exists(rocm_home):
rocm_home = None
return rocm_home if rocm_home is not None else ""
......@@ -104,6 +102,7 @@ def _find_rocm_home() -> str:
# Cache control
class CacheState:
"""Class to manage global kernel caching state."""
_enabled = True
@classmethod
......@@ -230,13 +229,11 @@ class Environment:
TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))
# Kernel Build options
TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
"1") # print kernel name on compile
TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile
TILELANG_DISABLE_CACHE = EnvVar(
"TILELANG_DISABLE_CACHE",
"0") # disable kernel cache, usually for unit testing / debugging, high priority
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE",
"0") # DEPRECATED! clear cache automatically if set
"TILELANG_DISABLE_CACHE", "0"
) # disable kernel cache, usually for unit testing / debugging, high priority
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # DEPRECATED! clear cache automatically if set
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
......@@ -244,12 +241,9 @@ class Environment:
# Auto-tuning settings
TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
"0.9") # percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
"-1") # -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
"-1") # -1 means no limit
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit
# TVM integration
SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
......@@ -323,18 +317,18 @@ def prepend_pythonpath(path):
if env.TVM_IMPORT_PYTHON_PATH is not None:
prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
else:
tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python')
tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python")
assert os.path.exists(tvm_path), tvm_path
if tvm_path not in sys.path:
prepend_pythonpath(tvm_path)
env.TVM_IMPORT_PYTHON_PATH = tvm_path
# By default, the built TVM-related libraries are stored in TL_LIBS.
if os.environ.get("TVM_LIBRARY_PATH") is None:
os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
# Initialize CUTLASS paths
if os.environ.get("TL_CUTLASS_PATH", None) is None:
cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include')
cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include")
if os.path.exists(cutlass_inc_path):
os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path
else:
......@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
# Initialize COMPOSABLE_KERNEL paths
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include')
ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include")
if os.path.exists(ck_inc_path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path
else:
......
......@@ -4,7 +4,7 @@ import tilelang.language as T
def shared_16x4_to_local_64x1_layout_A(i, j):
thread_id = (j * 16 + i)
thread_id = j * 16 + i
return thread_id, convert(0)
......@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
def shared_4x16_to_local_64x1_layout_B(i, j):
thread_id = (i * 16 + j)
thread_id = i * 16 + j
return thread_id, convert(0)
......@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_C(i, j):
thread_id = j + (i // 4) * 16
local = (i % 4)
local = i % 4
return thread_id, local
......@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_A(i, j):
thread_id = i + 16 * (j // 4)
local = (j % 4)
local = j % 4
return thread_id, local
......@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_B(i, j):
thread_id = j + (i // 4) * 16
local = (i % 4)
local = i % 4
return thread_id, local
......@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
def shared_16x32_to_local_64x8_layout_A(i, j):
thread_id = i + 16 * (j // 8)
local = (j % 8)
local = j % 8
return thread_id, local
......@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
def shared_16x32_to_local_64x8_layout_B(i, j):
thread_id = j + (i // 8) * 16
local = (i % 8)
local = i % 8
return thread_id, local
......@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
def shared_16x64_to_local_64x16_layout_A(i, j):
thread_id = i + 16 * (j // 16)
local = (j % 16)
local = j % 16
return thread_id, local
......@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
def shared_16x64_to_local_64x16_layout_B(i, j):
thread_id = i + 16 * (j // 16)
local = (j % 16)
local = j % 16
return thread_id, local
......
......@@ -6,7 +6,7 @@ from tvm import tir
from tvm.ir import Range
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tvm.runtime import convert
from .utils import (mfma_store_index_map)
from .utils import mfma_store_index_map
from typing import Literal, Callable
from tilelang.utils import is_fragment
......@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
......@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
def _initialize_mfma_prefix(self, k_dim=16):
in_dtype, out_dtype = self.a_dtype, self.accum_dtype
M_DIM, N_DIM = self.M_DIM, self.N_DIM
out_dtype_abbrv = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32"
}[out_dtype]
out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
in_dtype_abbrv = {
"bfloat16": "bf16",
......@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False):
k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed
if k_dim == 4:
......@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
if is_b:
index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
reverse_index_map = (
thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
)
elif k_dim == 16:
index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
reverse_index_map = (
thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
)
if is_b:
index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
reverse_index_map = (
thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
)
elif k_dim == 32:
index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
reverse_index_map = (
thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
)
if is_b:
index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
reverse_index_map = (
thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
)
elif k_dim == 64:
index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
reverse_index_map = (
thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
)
if is_b:
index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
reverse_index_map = (
thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
)
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
......@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
else:
return self.thread_var
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
'''
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
......@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = thread_id % WARP_SIZE, (
thread_id //
WARP_SIZE) % block_col_warps, (thread_id //
(WARP_SIZE * block_col_warps)) % block_row_warps,
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = thread_id % WARP_SIZE, (
thread_id //
WARP_SIZE) % block_row_warps, (thread_id //
(WARP_SIZE * block_row_warps)) % block_col_warps,
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
......@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k),
warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
B_base1 + r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
else:
for j in T.serial(warp_cols):
......@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row,
B_base1 + r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mfma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * N_DIM +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
i * (warp_cols * local_size_out) + j * local_size_out + local_id
]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * warp_cols * local_size_out +
j * local_size_out + local_id]
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
i * warp_cols * local_size_out + j * local_size_out + local_id
]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
......@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter:
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
return _warp_stmatrix_global(C_local_buf, C_buf,
thread_binding) if is_global else _warp_stmatrix_shared(
C_local_buf, C_buf, thread_binding)
def make_mfma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
return (
_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
......@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
......@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter:
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
......@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter:
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r *
self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
[micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
......@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
......@@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter:
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
......@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk * (chunk // micro_size_k) + ki,
warp_m * warp_rows + i,
)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
else:
print(self.a_preshuffle)
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding,
rk) if is_global else _warp_ldmatrix_a_shared(
A_local_buf, A_buf, ki, thread_binding, rk)
return (
_warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
if is_global
else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
)
def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
warp_cols = self.warp_cols
......@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
......@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk * (chunk // micro_size_k) + ki,
warp_n * warp_cols + j,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
rk) if is_global else _warp_ldmatrix_b_shared(
B_local_buf, B_buf, ki, thread_binding, rk)
return (
_warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
if is_global
else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
)
......@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
from .utils import mma_store_index_map, mma_store_index_map_fp64
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
......@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter:
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter:
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.a_dtype).bits == 64:
warp_row_tiles = self.warp_row_tiles
......@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter:
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_buf[A_base0 + wk,
A_base1 + wi] if a_transposed else A_buf[A_base0 + wi,
A_base1 + wk]
A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk]
if ldmatrix_available:
T.ptx_ldmatrix(
......@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter:
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
if a_transposed:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk,
A_base1 + wi + mi]
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
else:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi,
A_base1 + wk + mk]
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.b_dtype).bits == 64:
warp_col_tiles = self.warp_col_tiles
......@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter:
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
B_stride_last = B_buf.shape[-1]
replicate_b = (self.n_dim == 16)
replicate_b = self.n_dim == 16
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
......@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter:
)
if ldmatrix_available:
B_shared_buf_elem = B_buf[B_base0 + wi,
B_base1 + wk] if b_transposed else B_buf[B_base0 + wk,
B_base1 + wi]
B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi]
T.ptx_ldmatrix(
b_dtype,
......@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter:
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
if b_transposed:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
else:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter:
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16)
replicate_b = self.n_dim == 16
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
......@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter:
B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out +
lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False), # saturate
)
......@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter:
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * n_dim +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[
i * (warp_cols * local_size_out) + j * local_size_out + local_id
]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
i * (warp_cols * local_size_out) + j * local_size_out + local_id
]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
......@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter:
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
return (
_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
......@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
......@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
......@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
......@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter:
from tilelang.utils import is_fragment
shape = local_buf.shape
assert is_fragment(
local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
......@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf[
T.address_of(
A_shared_buf[
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]),
]
),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
elif transform_kind_a == TransformKind.InterWarpTransform:
......@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_m * warp_rows + j,
rk * (chunk // micro_size_k) + ki,
)
rii, rjj = (tx * local_size_a +
local_id) // micro_size_k, (tx * local_size_a + local_id) % (
micro_size_k)
A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj])
rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k)
A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj]
else:
raise ValueError("Unsupported TransformKind for Input A")
......@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
rii, rjj = (tx * local_size_dequantize +
local_id) // (micro_size_k // num_elems_per_byte), (
tx * local_size_dequantize + local_id) % (
micro_size_k // num_elems_per_byte)
B_local_buf[j * local_size_dequantize + local_id] = (
B_shared_buf[ri, rj, rii, rjj])
rii, rjj = (
(tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte),
(tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte),
)
B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj]
else:
raise ValueError("Unsupported TransformKind for Input B")
......@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
......@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......
......@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + (
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
return row, col
......@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4))
col = local_id
return row, col
......
......@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32
if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32")
mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32",
)
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter:
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
......@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter:
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
......@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter:
for j in T.vectorized(local_size_b):
if b_transposed:
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi,
B_base1 + wk + mk]
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
else:
mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk,
B_base1 + wi + mi]
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)
def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter:
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
......@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
......@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(
i, j)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
else:
raise ValueError(f"Unsupported matrix {matrix}")
......@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter:
return lane_id, local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_fn=forward,
replicate=2)
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
......@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
......
......@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int:
return (thread_id // 4) * 2 + (thread_id % 4) % 2
def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 4 + local_id * 8
col = logical_id % 4
return row, col
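
Since these mapping helpers are plain index arithmetic, a tiny self-contained check makes the pattern visible. The formulas below restate get_logical_id_32bit and metadata_8bit_load_32x4_to_shared_16x4_layout_32bit verbatim; only the loop and prints are added for illustration.

# Verbatim restatement of the 32-bit metadata mapping, plus a few sample points.
def shared_coord_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
    logical_id = (thread_id // 4) * 2 + (thread_id % 4) % 2
    row = logical_id // 4 + local_id * 8
    col = logical_id % 4
    return row, col

for tid in (0, 1, 2, 3, 4, 16):
    print(tid, shared_coord_32bit(tid, 0))
# -> (0, 0), (0, 1), (0, 0), (0, 1), (0, 2), (2, 0): threads 0 and 2 (and 1 and 3)
#    alias the same coordinate, and local_id advances the row in steps of 8.
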
def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 2 + local_id * 8
col = logical_id % 2
return row, col
def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit
def get_logical_id_8bit(thread_id: int) -> int:
return thread_id
def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 4 + local_id
return row, col
def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 2 + local_id
return row, col
def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
# local_id is always 0
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 4 + (logical_id % 2) * 8
......
......@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter:
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
self.local_size_e = (
m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
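
For orientation, the arithmetic above reduces to a handful of integers once concrete tile sizes are plugged in. The sketch below assumes m_dim = n_dim = k_dim = 16, warp_size = 32, SPARSE_FACTOR = 2 (as in 2:4 structured sparsity), and e_factor = 8 with a replicate factor of 1; those last three values are illustrative assumptions, not read from this diff.

# Worked arithmetic for the per-thread register counts above, under the
# assumed constants named in the lead-in (not taken from this commit).
m_dim, n_dim, k_dim, warp_size = 16, 16, 16, 32
sparse_factor, e_factor, e_replicate = 2, 8, 1  # assumed values

local_size_a = (m_dim * k_dim) // warp_size // sparse_factor            # 4
local_size_e = (m_dim * k_dim) // e_factor // warp_size * e_replicate   # 1
local_size_b = (n_dim * k_dim) // warp_size                             # 8
local_size_out = (m_dim * n_dim) // warp_size                           # 8
print(local_size_a, local_size_e, local_size_b, local_size_out)
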
......@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter:
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter:
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
if ldmatrix_available:
......@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter:
else:
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a +
j] = A_shared_buf[wk + mk, wi +
mi] if a_transposed else A_shared_buf[wi + mi,
wk + mk]
A_local_buf[i * local_size_a + j] = (
A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
)
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter:
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
# Assign E_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
rk * warp_k + ki * micro_size_k) // self.e_factor
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor
for j in T.serial(local_size_e):
mi, mk = mma_load_layout(tx, j)
E_local_buf[i * local_size_e +
j] = E_shared_buf[wk + mk,
wi + mi] if trans else E_shared_buf[wi + mi,
wk + mk]
E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]
return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
......@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter:
b_dtype = self.b_dtype
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
replicate_b = (self.n_dim == 16)
replicate_b = self.n_dim == 16
# ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
......@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter:
)
if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk,
wi]
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
if replicate_b:
T.ptx_ldmatrix(
......@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf.data,
i * local_size_b + lift(local_size_b) // 2,
T.address_of(B_shared_buf_elem),
get_ldmatrix_offset_b("B", tx,
lift(local_size_b) // 2, stride, b_dtype,
b_transposed),
get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed),
)
else:
T.ptx_ldmatrix(
......@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter:
# must be transposed.
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b +
j] = B_shared_buf[wi + mi, wk +
mk] if b_transposed else B_shared_buf[wk + mk,
wi + mi]
B_local_buf[i * local_size_b + j] = (
B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi]
)
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma_sp(self,
A_local_buf: Buffer,
E_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr = 0):
def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter:
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16)
replicate_b = self.n_dim == 16
a_is_fragment = is_fragment(A_local_buf)
e_is_fragment = is_fragment(E_local_buf)
......@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out +
lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
E_local_buf.data, # metadata
e_local_stride + i * local_size_e, # metadata offset
self.SPARSE_SELECTOR, # sparse_selector
......@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter:
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * n_dim +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[
i * (warp_cols * local_size_out) + j * local_size_out + local_id
]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
i * (warp_cols * local_size_out) + j * local_size_out + local_id
]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
......@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter:
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
return (
_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
......@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
......@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
# so the b matrix expects a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
......@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter:
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order
[micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r]
if is_sr_axis_order
else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
......@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
......
......@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first: bool = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
num_elems_per_byte, is_m_first, thread_var)
super().__init__(
a_dtype,
b_dtype,
accum_dtype,
a_transposed,
b_transposed,
block_row_warps,
block_col_warps,
warp_row_tiles,
warp_col_tiles,
chunk,
reduce_k,
num_elems_per_byte,
is_m_first,
thread_var,
)
def _assign_a_shared_layout(self, layout: Layout):
self.a_shared_layout = layout
......@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else:
raise ValueError(f"Unsupported swizzle mode: {layout}")
def tcgen05mma(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
mbar,
clear_accum: PrimExpr = False):
def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False):
if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
......@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_dtype_in_bits = DataType(accum_dtype).bits
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 5:
raise ValueError(
f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
)
atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
......@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
......@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
for ki in T.unroll(0, (k_dim // micro_size_k)):
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
A_elem_offset = (
ki % ak_atom_size
) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
(ki % ak_atom_size) * micro_size_k
+ i * atom_m * a_swizzle_atom_elems
+ (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
if a_is_k_major
else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
)
B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
B_elem_offset = (
(ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
+ (ki % bk_atom_size) * micro_size_k
+ j * atom_n * b_swizzle_atom_elems
if b_is_k_major
else (
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
A_byte_offset = A_elem_offset * elems_in_bytes
B_byte_offset = B_elem_offset * elems_in_bytes
C_offset = (i * n_dim + j * tmem_col_step
) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
T.ptx_tcgen05_mma_ss(
a_dtype_abbrv,
......@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
if len(tmem_buf.shape) != 2:
raise ValueError(
f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
m = int(tmem_buf.shape[0])
n = int(tmem_buf.shape[1])
......@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
meta = self.get_tcgen5_mma_meta(m, n, k)
if len(meta) != 5:
raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
raise ValueError(
f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
)
atom_m, atom_n, _, _, _ = (int(x) for x in meta)
if m % atom_m != 0 or n % atom_n != 0:
raise ValueError(
f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
)
raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})")
def forward(i: PrimExpr, j: PrimExpr):
atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
......@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return Layout([m, n], forward)
def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
return _ffi_api.get_tcgen5_mma_meta(
int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
def get_tcgen5_instr_desc(
self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int
) -> PrimExpr:
desc = _ffi_api.get_tcgen5_instr_desc(
atom_m,
atom_n,
......
......@@ -10,7 +10,7 @@ from .mma_layout import (
mma_store_32x8_to_shared_16x16_layout,
mma_store_32x2_to_shared_8x8_layout_fp64,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401
......
......@@ -15,9 +15,11 @@ from tilelang.layout import (
make_linear_layout,
)
from tvm.runtime import convert
from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a,
from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a)
shared_16x32_to_mma_32x16_layout_sr_a,
)
lift = convert
......@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
num_elems_per_byte, is_m_first, thread_var)
super().__init__(
a_dtype,
b_dtype,
accum_dtype,
a_transposed,
b_transposed,
block_row_warps,
block_col_warps,
warp_row_tiles,
warp_col_tiles,
chunk,
reduce_k,
num_elems_per_byte,
is_m_first,
thread_var,
)
self._initialize_wgmma_prefix(self.n_dim)
def _assign_a_shared_layout(self, layout: Layout):
......@@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def _initialize_wgmma_prefix(self, n_dim: int = 16):
inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
assert inst_n % 8 == 0, (
f"inst_n must be a multiple of 8, got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
)
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert 8 <= inst_n <= 256, (
f"inst_n must be within [8, 256], got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
)
# 256 bits per instruction
inst_k = 256 // DataType(self.a_dtype).bits
self.wgmma_inst_m = inst_m
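
The instruction-shape selection above boils down to a gcd and a bit-width division; a short numeric sketch, assuming fp16 operands (16-bit elements) and warp_col_tiles = 64, shows the values the asserts are guarding.

# Numeric sketch of the WGMMA instruction-shape computation, under the
# assumed values named in the lead-in.
from math import gcd

warp_col_tiles, a_dtype_bits = 64, 16       # illustrative inputs
inst_m = 64
inst_n = gcd(warp_col_tiles, 256)           # 64: a multiple of 8 inside [8, 256]
inst_k = 256 // a_dtype_bits                # 16 elements per 256-bit K slice
assert inst_n % 8 == 0 and 8 <= inst_n <= 256
print(inst_m, inst_n, inst_k)               # 64 64 16
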
......@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else:
raise ValueError(f"Unsupported swizzle mode: {layout}")
def wgmma(self,
A_region: BufferRegion,
B_region: BufferRegion,
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
def wgmma(
self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0
):
if is_fragment(A_region):
return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
......@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
......@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
if a_m_axis_atoms <= 1:
a_leading_byte_offset = 0
else:
a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (
a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
if a_m_axis_atoms <= 1:
a_stride_byte_offset = 8 * elems_in_bytes * m_dim
else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
......@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive()
......@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
warp_i = (warp_m // 4) * num_inst_m + i
warp_j = warp_n * num_inst_n + j
A_offset = (
ki % ak_atom_size
) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
(ki % ak_atom_size) * micro_size_k
+ warp_i * 64 * a_swizzle_atom_elems
+ (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
if a_is_k_major
else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
)
B_offset = (
(ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
+ (ki % bk_atom_size) * micro_size_k
+ warp_j * wgmma_inst_n * b_swizzle_atom_elems
if b_is_k_major
else (
ki * b_swizzle_atom_elems * micro_size_k
+ warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
T.ptx_wgmma_ss(
accum_dtype,
wgmma_prefix,
a_is_k_major,
b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
desc_a.data,
(A_offset * elems_in_bytes) >> 4,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
T.warpgroup_commit_batch()
if wg_wait >= 0:
......@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return _warp_mma(A_ptr, B_ptr, C_buf)
def wgmma_rs(self,
A_region: BufferRegion,
B_region: BufferRegion,
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
def wgmma_rs(
self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0
):
local_size_a = self.local_size_a
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
......@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
b_is_k_major = self.b_transposed
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
......@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive()
......@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = (
ki // bk_atom_size
) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
ki % bk_atom_size) * micro_size_k if b_is_k_major else (
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n *
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
(ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
+ warp_j * wgmma_inst_n * b_swizzle_atom_elems
+ (ki % bk_atom_size) * micro_size_k
if b_is_k_major
else (
ki * b_swizzle_atom_elems * micro_size_k
+ warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_rs(
accum_dtype,
......@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A"], "matrix should be A for WGMMA"
dtype = self.a_dtype
dtype_bits = DataType(dtype).bits
......@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# the layout of mma.sync is row.col.
# so the b matrix expects a transposed basic layout
transform_func: Callable = None
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
......@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
replicate = block_col_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
else:
# rs condition, transposed_a matrix
warp_fragment = base_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
return block_fragment
......
......@@ -7,23 +7,19 @@ from tilelang import _ffi_api
@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
...
class Fill(Node, Scriptable): ...
@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
...
class AtomicAdd(Node, Scriptable): ...
@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
...
class Copy(Node, Scriptable): ...
@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
...
class Conv2DIm2ColOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.GemmWarpPolicy")
......@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable):
m_warp: int
n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
is_wgmma: bool):
_ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma)
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool):
_ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma)
return self.m_warp, self.n_warp
......@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable):
m_warp: int
n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
is_wgmma: bool, bits: int):
_ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma, bits)
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int):
_ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits)
return self.m_warp, self.n_warp
@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
...
class Gemm(Node, Scriptable): ...
@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
...
class GemmSP(Node, Scriptable): ...
@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
...
class FinalizeReducerOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
...
class ParallelOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
...
class ReduceOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
...
class CumSumOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
...
class RegionOp(Node, Scriptable): ...
@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
...
class ReduceType(Node, Scriptable): ...