Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -16,12 +16,7 @@ from tvm.base import py_str ...@@ -16,12 +16,7 @@ from tvm.base import py_str
from tvm.contrib.rocm import get_rocm_arch, find_rocm_path from tvm.contrib.rocm import get_rocm_arch, find_rocm_path
def compile_hip(code, def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False):
target_format="hsaco",
arch=None,
options=None,
path_target=None,
verbose=False):
"""Compile HIP code with hipcc. """Compile HIP code with hipcc.
Parameters Parameters
...@@ -61,7 +56,7 @@ def compile_hip(code, ...@@ -61,7 +56,7 @@ def compile_hip(code,
file_target = path_target if path_target else temp_target file_target = path_target if path_target else temp_target
cmd = ["hipcc"] cmd = ["hipcc"]
cmd += ["-O3", '-c'] cmd += ["-O3", "-c"]
if isinstance(arch, str): if isinstance(arch, str):
cmd += [f"--offload-arch={arch}"] cmd += [f"--offload-arch={arch}"]
if target_format == "hsaco": if target_format == "hsaco":
......
# pylint: disable=invalid-name # pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py # modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system""" """Utility to invoke nvcc compiler in the system"""
from __future__ import annotations from __future__ import annotations
import os import os
...@@ -18,12 +19,7 @@ from tvm.base import py_str ...@@ -18,12 +19,7 @@ from tvm.base import py_str
from tvm.contrib import utils from tvm.contrib import utils
def compile_cuda(code, def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False):
target_format="ptx",
arch=None,
options=None,
path_target=None,
verbose=False):
"""Compile cuda code with NVCC from env. """Compile cuda code with NVCC from env.
Parameters Parameters
...@@ -67,7 +63,7 @@ def compile_cuda(code, ...@@ -67,7 +63,7 @@ def compile_cuda(code,
temp_target = temp.relpath(f"{file_name}.{target_format}") temp_target = temp.relpath(f"{file_name}.{target_format}")
pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None)) kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None)
if kernels_output_dir is not None: if kernels_output_dir is not None:
if not os.path.isdir(kernels_output_dir): if not os.path.isdir(kernels_output_dir):
os.makedirs(kernels_output_dir) os.makedirs(kernels_output_dir)
...@@ -114,10 +110,7 @@ def compile_cuda(code, ...@@ -114,10 +110,7 @@ def compile_cuda(code,
print(py_str(out)) print(py_str(out))
if proc.returncode != 0: if proc.returncode != 0:
msg = f"{code}\n" \ msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n"
f"Compilation error:\n" \
f"{py_str(out)}\n" \
f"Command: {' '.join(cmd)}\n"
raise RuntimeError(msg) raise RuntimeError(msg)
with open(file_target, "rb") as f: with open(file_target, "rb") as f:
...@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] ...@@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries). # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if compile_flags: if compile_flags:
import shlex import shlex
for flag in compile_flags: for flag in compile_flags:
# Split each string like a shell would, preserving quoted args # Split each string like a shell would, preserving quoted args
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)] tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
...@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] ...@@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
return options return options
def get_ptx_from_source(code: str, def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
""" """
Compile CUDA C++ source to PTX using NVCC and return as text. Compile CUDA C++ source to PTX using NVCC and return as text.
...@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None: ...@@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None:
return None return None
def get_sass_from_source(code: str, def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str:
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
""" """
Compile CUDA C++ source to CUBIN and disassemble to SASS. Compile CUDA C++ source to CUBIN and disassemble to SASS.
...@@ -246,9 +236,7 @@ def get_sass_from_source(code: str, ...@@ -246,9 +236,7 @@ def get_sass_from_source(code: str,
cand_nvdisasm = _find_tool("nvdisasm") cand_nvdisasm = _find_tool("nvdisasm")
cand_cuobjdump = _find_tool("cuobjdump") cand_cuobjdump = _find_tool("cuobjdump")
if not cand_nvdisasm and not cand_cuobjdump: if not cand_nvdisasm and not cand_cuobjdump:
raise RuntimeError( raise RuntimeError("Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.")
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
last_err: str | None = None last_err: str | None = None
try: try:
# Attempt nvdisasm first # Attempt nvdisasm first
...@@ -268,8 +256,7 @@ def get_sass_from_source(code: str, ...@@ -268,8 +256,7 @@ def get_sass_from_source(code: str,
return text return text
last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}" last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
# If we reach here, all attempts failed # If we reach here, all attempts failed
raise RuntimeError(f"SASS disassembly failed. Tried tools: " raise RuntimeError(f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
finally: finally:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
os.remove(cubin_path) os.remove(cubin_path)
...@@ -438,8 +425,7 @@ def get_target_compute_version(target=None): ...@@ -438,8 +425,7 @@ def get_target_compute_version(target=None):
if tvm.cuda(0).exist: if tvm.cuda(0).exist:
return tvm.cuda(0).compute_version return tvm.cuda(0).compute_version
raise ValueError("No CUDA architecture was specified or GPU detected." raise ValueError("No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target.")
"Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version) -> tuple[int, int]: def parse_compute_version(compute_version) -> tuple[int, int]:
...@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None): ...@@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None):
warnings.warn( warnings.warn(
"Tensorcore will be disabled due to no CUDA architecture specified." "Tensorcore will be disabled due to no CUDA architecture specified."
"Try specifying it by adding '-arch=sm_xx' to your target.", "Try specifying it by adding '-arch=sm_xx' to your target.",
stacklevel=2) stacklevel=2,
)
return False return False
compute_version = target.attrs["arch"] compute_version = target.attrs["arch"]
# Compute version will be in the form "sm_{major}{minor}" # Compute version will be in the form "sm_{major}{minor}"
......
...@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]: ...@@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]:
return (major, minor) return (major, minor)
def compile_cuda(code: str, def compile_cuda(
target_format: Literal["ptx", "cubin"] = "ptx", code: str,
arch: int | None = None, target_format: Literal["ptx", "cubin"] = "ptx",
options: str | list[str] | None = None, arch: int | None = None,
verbose: bool = False) -> bytearray: options: str | list[str] | None = None,
verbose: bool = False,
) -> bytearray:
"""Compile cuda code with NVRTC. """Compile cuda code with NVRTC.
Parameters Parameters
...@@ -43,8 +45,7 @@ def compile_cuda(code: str, ...@@ -43,8 +45,7 @@ def compile_cuda(code: str,
if arch is None: if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`. # If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc. # Target arch could be a str like "80", "90", "90a", etc.
major, minor = parse_compute_version( major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
get_target_compute_version(Target.current(allow_none=True)))
arch = major * 10 + minor arch = major * 10 + minor
prefix = "compute" if target_format == "ptx" else "sm" prefix = "compute" if target_format == "ptx" else "sm"
suffix = "a" if arch >= 90 else "" suffix = "a" if arch >= 90 else ""
...@@ -77,8 +78,7 @@ def compile_cuda(code: str, ...@@ -77,8 +78,7 @@ def compile_cuda(code: str,
compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0] compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0]
if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
msg = f"{code}\n" \ msg = f"{code}\nCompilation error:\n"
f"Compilation error:\n"
if verbose: if verbose:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program) result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}" assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}"
...@@ -105,7 +105,6 @@ def compile_cuda(code: str, ...@@ -105,7 +105,6 @@ def compile_cuda(code: str,
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}" assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}"
# Destroy handler # Destroy handler
assert nvrtc.nvrtcDestroyProgram( assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
return result_bytes return result_bytes
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Utility for ROCm backend""" """Utility for ROCm backend"""
# ruff: noqa # ruff: noqa
import re import re
import subprocess import subprocess
...@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"): ...@@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"):
gpu_arch = match.group(1) gpu_arch = match.group(1)
return gpu_arch return gpu_arch
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print(f"Unable to execute rocminfo command, \ print(
f"Unable to execute rocminfo command, \
please ensure ROCm is installed and you have an AMD GPU on your system.\ please ensure ROCm is installed and you have an AMD GPU on your system.\
using default {gpu_arch}.") using default {gpu_arch}."
)
return gpu_arch return gpu_arch
......
"""The compiler for TL programs.""" """The compiler for TL programs."""
from __future__ import annotations from __future__ import annotations
import os import os
...@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target): ...@@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target):
def has_device_kernel_launch(attrs) -> bool: def has_device_kernel_launch(attrs) -> bool:
"""Check if the attributes indicate a device kernel launch.""" """Check if the attributes indicate a device kernel launch."""
return bool(attrs and "calling_conv" in attrs and return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
def is_device_call_c_device(func: tir.PrimFunc): def is_device_call_c_device(func: tir.PrimFunc):
attrs = func.attrs attrs = func.attrs
calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT) calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT)
is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC) is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC
# Check if it's a C target # Check if it's a C target
if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked: if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked:
...@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: ...@@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
if var in func.buffer_map: if var in func.buffer_map:
tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
else: else:
if var.dtype == 'handle': if var.dtype == "handle":
raise ValueError( raise ValueError(
f'Handle parameter {var} must be mapped to a buffer.\n' f"Handle parameter {var} must be mapped to a buffer.\n"
f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.') f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer."
)
tensor_types.append(KernelParam.from_var(var)) tensor_types.append(KernelParam.from_var(var))
return tensor_types return tensor_types
def canon_target_host(target: str | Target, target_host: str | Target | None): def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host: if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
...@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tilelang.transform.LowerIntrin()(device_mod) device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod) device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda": if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")( device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
device_mod, target)
elif target.kind.name == "hip": elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")( device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
device_mod, target)
elif target.kind.name == "c": elif target.kind.name == "c":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
elif target.kind.name == "llvm": elif target.kind.name == "llvm":
...@@ -222,12 +220,12 @@ def lower( ...@@ -222,12 +220,12 @@ def lower(
enable_host_codegen=False, enable_host_codegen=False,
enable_device_compile=False, enable_device_compile=False,
) -> CompiledArtifact: ) -> CompiledArtifact:
''' """
enable_host_codegen: whether to enable host codegen, default is False, as we have our enable_host_codegen: whether to enable host codegen, default is False, as we have our
own host codegen implementation in jit. own host codegen implementation in jit.
enable_device_compile: whether to enable device codegen, default is False, as we have our enable_device_compile: whether to enable device codegen, default is False, as we have our
own device codegen implementation in jit. own device codegen implementation in jit.
''' """
mod = func_or_mod mod = func_or_mod
params = None params = None
...@@ -259,14 +257,11 @@ def lower( ...@@ -259,14 +257,11 @@ def lower(
host_mod = tir.transform.Filter(_is_host_call)(mod) host_mod = tir.transform.Filter(_is_host_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod)
codegen_mod = device_codegen( codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target)
device_mod, target) if enable_device_compile else device_codegen_without_compile(
device_mod, target)
if enable_host_codegen: if enable_host_codegen:
host_mod = host_codegen(host_mod, target_host) host_mod = host_codegen(host_mod, target_host)
host_mod.import_module(codegen_mod) host_mod.import_module(codegen_mod)
return CompiledArtifact( return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
...@@ -14,6 +15,7 @@ class KernelParam: ...@@ -14,6 +15,7 @@ class KernelParam:
Represents parameters for a kernel operation, storing dtype and shape information. Represents parameters for a kernel operation, storing dtype and shape information.
Used to describe tensor or scalar parameters in TVM/PyTorch interop. Used to describe tensor or scalar parameters in TVM/PyTorch interop.
""" """
dtype: torch.dtype # PyTorch data type of the parameter dtype: torch.dtype # PyTorch data type of the parameter
shape: list[int | Var] # List of dimensions, can be integers or TVM variables shape: list[int | Var] # List of dimensions, can be integers or TVM variables
...@@ -109,6 +111,7 @@ class CompiledArtifact: ...@@ -109,6 +111,7 @@ class CompiledArtifact:
Represents a compiled kernel artifact containing both host and device code. Represents a compiled kernel artifact containing both host and device code.
Stores all necessary components for kernel execution in the TVM runtime. Stores all necessary components for kernel execution in the TVM runtime.
""" """
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel
......
...@@ -6,8 +6,7 @@ from tilelang.transform import PassContext ...@@ -6,8 +6,7 @@ from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper from tilelang.contrib.nvcc import have_tma, is_hopper
def allow_warp_specialized(pass_ctx: PassContext | None = None, def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
target: Target | None = None) -> bool:
# avoid circular import # avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target from tilelang.jit.adapter.utils import is_cuda_target
...@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None, ...@@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None,
return not disable_warp_specialized return not disable_warp_specialized
def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
target: Target | None = None) -> bool:
if pass_ctx is None: if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target): if not have_tma(target):
...@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> ...@@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) ->
return enable_global_thread_sync return enable_global_thread_sync
def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
target: Target | None = None) -> bool:
if pass_ctx is None: if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool( enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
if allow_warp_specialized(pass_ctx=pass_ctx, target=target): if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
# when warp specialization is enabled, as different warp threads may access different # when warp specialization is enabled, as different warp threads may access different
...@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: ...@@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
return ["txt", "png", "pdf", "svg"] return ["txt", "png", "pdf", "svg"]
if "," in formats_str: if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')] formats_list = [f.strip() for f in formats_str.split(",")]
else: else:
formats_list = [formats_str] formats_list = [formats_str]
...@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# MergeSharedMemoryAllocations must be applied after SplitHostDevice # MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function # because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
mod = tilelang.transform.MergeSharedMemoryAllocations( mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
enable_aggressive_merge=enable_aggressive_merge)(
mod)
mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass # Inject PTX async copy must behind the thread sync pass
......
...@@ -10,36 +10,34 @@ from dataclasses import dataclass ...@@ -10,36 +10,34 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES # SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") CUTLASS_NOT_FOUND_MESSAGE = "CUTLASS is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend." ", which may lead to compilation bugs when utilize tilelang backend."
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = ( COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = "Composable Kernel is not installed or found in the expected path"
"Composable Kernel is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend." ", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") TL_TEMPLATE_NOT_FOUND_MESSAGE = "TileLang is not installed or found in the expected path"
", which may lead to compilation bugs when utilize tilelang backend." ", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"
TL_ROOT = os.path.dirname(os.path.abspath(__file__)) TL_ROOT = os.path.dirname(os.path.abspath(__file__))
# Only expose the internal lib directory to sys.path to avoid shadowing # Only expose the internal lib directory to sys.path to avoid shadowing
# common top-level module names (e.g., utils, analysis) from user projects. # common top-level module names (e.g., utils, analysis) from user projects.
TL_LIBS = [os.path.join(TL_ROOT, 'lib')] TL_LIBS = [os.path.join(TL_ROOT, "lib")]
TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)] TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)]
DEV = False DEV = False
THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty') THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty")
if not os.path.exists(THIRD_PARTY_ROOT): if not os.path.exists(THIRD_PARTY_ROOT):
DEV = True DEV = True
tl_dev_root = os.path.dirname(TL_ROOT) tl_dev_root = os.path.dirname(TL_ROOT)
dev_lib_root = os.path.join(tl_dev_root, 'build') dev_lib_root = os.path.join(tl_dev_root, "build")
# In dev builds, place artifacts under build/lib and point search path there # In dev builds, place artifacts under build/lib and point search path there
# to avoid adding the entire build root to sys.path. # to avoid adding the entire build root to sys.path.
TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')] TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")]
THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty') THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty")
logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}') logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}")
assert TL_LIBS and all( assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}"
os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}'
for lib in TL_LIBS: for lib in TL_LIBS:
if lib not in sys.path: if lib not in sys.path:
...@@ -52,7 +50,7 @@ def _find_cuda_home() -> str: ...@@ -52,7 +50,7 @@ def _find_cuda_home() -> str:
Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
""" """
# Guess #1 # Guess #1
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None: if cuda_home is None:
# Guess #2 # Guess #2
nvcc_path = shutil.which("nvcc") nvcc_path = shutil.which("nvcc")
...@@ -70,15 +68,15 @@ def _find_cuda_home() -> str: ...@@ -70,15 +68,15 @@ def _find_cuda_home() -> str:
else: else:
# Guess #3 # Guess #3
if sys.platform == 'win32': if sys.platform == "win32":
cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0] cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0]
else: else:
# Linux/macOS # Linux/macOS
if os.path.exists('/usr/local/cuda'): if os.path.exists("/usr/local/cuda"):
cuda_home = '/usr/local/cuda' cuda_home = "/usr/local/cuda"
elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'): elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"):
cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64' cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64"
# Validate found path # Validate found path
if cuda_home is None or not os.path.exists(cuda_home): if cuda_home is None or not os.path.exists(cuda_home):
...@@ -89,13 +87,13 @@ def _find_cuda_home() -> str: ...@@ -89,13 +87,13 @@ def _find_cuda_home() -> str:
def _find_rocm_home() -> str: def _find_rocm_home() -> str:
"""Find the ROCM install path.""" """Find the ROCM install path."""
rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME') rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME")
if rocm_home is None: if rocm_home is None:
rocmcc_path = shutil.which("hipcc") rocmcc_path = shutil.which("hipcc")
if rocmcc_path is not None: if rocmcc_path is not None:
rocm_home = os.path.dirname(os.path.dirname(rocmcc_path)) rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
else: else:
rocm_home = '/opt/rocm' rocm_home = "/opt/rocm"
if not os.path.exists(rocm_home): if not os.path.exists(rocm_home):
rocm_home = None rocm_home = None
return rocm_home if rocm_home is not None else "" return rocm_home if rocm_home is not None else ""
...@@ -104,6 +102,7 @@ def _find_rocm_home() -> str: ...@@ -104,6 +102,7 @@ def _find_rocm_home() -> str:
# Cache control # Cache control
class CacheState: class CacheState:
"""Class to manage global kernel caching state.""" """Class to manage global kernel caching state."""
_enabled = True _enabled = True
@classmethod @classmethod
...@@ -230,13 +229,11 @@ class Environment: ...@@ -230,13 +229,11 @@ class Environment:
TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp")) TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))
# Kernel Build options # Kernel Build options
TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile
"1") # print kernel name on compile
TILELANG_DISABLE_CACHE = EnvVar( TILELANG_DISABLE_CACHE = EnvVar(
"TILELANG_DISABLE_CACHE", "TILELANG_DISABLE_CACHE", "0"
"0") # disable kernel cache, usually for unit testing / debugging, high priority ) # disable kernel cache, usually for unit testing / debugging, high priority
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # DEPRECATED! clear cache automatically if set
"0") # DEPRECATED! clear cache automatically if set
# Kernel selection options # Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
...@@ -244,12 +241,9 @@ class Environment: ...@@ -244,12 +241,9 @@ class Environment:
# Auto-tuning settings # Auto-tuning settings
TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used
"0.9") # percent of CPUs used TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit
"-1") # -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
"-1") # -1 means no limit
# TVM integration # TVM integration
SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0") SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
...@@ -323,18 +317,18 @@ def prepend_pythonpath(path): ...@@ -323,18 +317,18 @@ def prepend_pythonpath(path):
if env.TVM_IMPORT_PYTHON_PATH is not None: if env.TVM_IMPORT_PYTHON_PATH is not None:
prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH) prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH)
else: else:
tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python') tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python")
assert os.path.exists(tvm_path), tvm_path assert os.path.exists(tvm_path), tvm_path
if tvm_path not in sys.path: if tvm_path not in sys.path:
prepend_pythonpath(tvm_path) prepend_pythonpath(tvm_path)
env.TVM_IMPORT_PYTHON_PATH = tvm_path env.TVM_IMPORT_PYTHON_PATH = tvm_path
# By default, the built TVM-related libraries are stored in TL_LIBS. # By default, the built TVM-related libraries are stored in TL_LIBS.
if os.environ.get("TVM_LIBRARY_PATH") is None: if os.environ.get("TVM_LIBRARY_PATH") is None:
os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS)
# Initialize CUTLASS paths # Initialize CUTLASS paths
if os.environ.get("TL_CUTLASS_PATH", None) is None: if os.environ.get("TL_CUTLASS_PATH", None) is None:
cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include') cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include")
if os.path.exists(cutlass_inc_path): if os.path.exists(cutlass_inc_path):
os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path
else: else:
...@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None: ...@@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
# Initialize COMPOSABLE_KERNEL paths # Initialize COMPOSABLE_KERNEL paths
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include') ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include")
if os.path.exists(ck_inc_path): if os.path.exists(ck_inc_path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path
else: else:
......
...@@ -4,7 +4,7 @@ import tilelang.language as T ...@@ -4,7 +4,7 @@ import tilelang.language as T
def shared_16x4_to_local_64x1_layout_A(i, j): def shared_16x4_to_local_64x1_layout_A(i, j):
thread_id = (j * 16 + i) thread_id = j * 16 + i
return thread_id, convert(0) return thread_id, convert(0)
...@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): ...@@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
def shared_4x16_to_local_64x1_layout_B(i, j): def shared_4x16_to_local_64x1_layout_B(i, j):
thread_id = (i * 16 + j) thread_id = i * 16 + j
return thread_id, convert(0) return thread_id, convert(0)
...@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): ...@@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_C(i, j): def shared_16x16_to_local_64x4_layout_C(i, j):
thread_id = j + (i // 4) * 16 thread_id = j + (i // 4) * 16
local = (i % 4) local = i % 4
return thread_id, local return thread_id, local
...@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): ...@@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_A(i, j): def shared_16x16_to_local_64x4_layout_A(i, j):
thread_id = i + 16 * (j // 4) thread_id = i + 16 * (j // 4)
local = (j % 4) local = j % 4
return thread_id, local return thread_id, local
...@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): ...@@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
def shared_16x16_to_local_64x4_layout_B(i, j): def shared_16x16_to_local_64x4_layout_B(i, j):
thread_id = j + (i // 4) * 16 thread_id = j + (i // 4) * 16
local = (i % 4) local = i % 4
return thread_id, local return thread_id, local
...@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id): ...@@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
def shared_16x32_to_local_64x8_layout_A(i, j): def shared_16x32_to_local_64x8_layout_A(i, j):
thread_id = i + 16 * (j // 8) thread_id = i + 16 * (j // 8)
local = (j % 8) local = j % 8
return thread_id, local return thread_id, local
...@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id): ...@@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
def shared_16x32_to_local_64x8_layout_B(i, j): def shared_16x32_to_local_64x8_layout_B(i, j):
thread_id = j + (i // 8) * 16 thread_id = j + (i // 8) * 16
local = (i % 8) local = i % 8
return thread_id, local return thread_id, local
...@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id): ...@@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
def shared_16x64_to_local_64x16_layout_A(i, j): def shared_16x64_to_local_64x16_layout_A(i, j):
thread_id = i + 16 * (j // 16) thread_id = i + 16 * (j // 16)
local = (j % 16) local = j % 16
return thread_id, local return thread_id, local
...@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id): ...@@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
def shared_16x64_to_local_64x16_layout_B(i, j): def shared_16x64_to_local_64x16_layout_B(i, j):
thread_id = i + 16 * (j // 16) thread_id = i + 16 * (j // 16)
local = (j % 16) local = j % 16
return thread_id, local return thread_id, local
......
...@@ -6,7 +6,7 @@ from tvm import tir ...@@ -6,7 +6,7 @@ from tvm import tir
from tvm.ir import Range from tvm.ir import Range
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tvm.runtime import convert from tvm.runtime import convert
from .utils import (mfma_store_index_map) from .utils import mfma_store_index_map
from typing import Literal, Callable from typing import Literal, Callable
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
...@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter: ...@@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter:
self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var self.thread_var = thread_var
...@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter: ...@@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter:
def _initialize_mfma_prefix(self, k_dim=16): def _initialize_mfma_prefix(self, k_dim=16):
in_dtype, out_dtype = self.a_dtype, self.accum_dtype in_dtype, out_dtype = self.a_dtype, self.accum_dtype
M_DIM, N_DIM = self.M_DIM, self.N_DIM M_DIM, N_DIM = self.M_DIM, self.N_DIM
out_dtype_abbrv = { out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32"
}[out_dtype]
in_dtype_abbrv = { in_dtype_abbrv = {
"bfloat16": "bf16", "bfloat16": "bf16",
...@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter: ...@@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter:
self.b_preshuffle = b_preshuffle self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False): def get_ldmatrix_index_map(self, is_b=False):
k_dim = self.k_dim * self.k_pack k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed transposed = self.a_transposed if not is_b else self.b_transposed
if k_dim == 4: if k_dim == 4:
...@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter: ...@@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter:
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
if is_b: if is_b:
index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B reverse_index_map = (
thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
)
elif k_dim == 16: elif k_dim == 16:
index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A reverse_index_map = (
thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
)
if is_b: if is_b:
index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B reverse_index_map = (
thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
)
elif k_dim == 32: elif k_dim == 32:
index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A reverse_index_map = (
thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
)
if is_b: if is_b:
index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B reverse_index_map = (
thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
)
elif k_dim == 64: elif k_dim == 64:
index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A reverse_index_map = (
thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
)
if is_b: if is_b:
index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B reverse_index_map = (
thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
)
else: else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
...@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter: ...@@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter:
else: else:
return self.thread_var return self.thread_var
def extract_thread_binding(self, def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
thread_id, """
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
''' which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] """
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
'''
WARP_SIZE = self.WARP_SIZE WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps block_col_warps = self.block_col_warps
...@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter: ...@@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter:
is_m_first = self.is_m_first is_m_first = self.is_m_first
if is_m_first: if is_m_first:
lane_id, warp_n, warp_m = thread_id % WARP_SIZE, ( lane_id, warp_n, warp_m = (
thread_id // thread_id % WARP_SIZE,
WARP_SIZE) % block_col_warps, (thread_id // (thread_id // WARP_SIZE) % block_col_warps,
(WARP_SIZE * block_col_warps)) % block_row_warps, (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
else: else:
lane_id, warp_m, warp_n = thread_id % WARP_SIZE, ( lane_id, warp_m, warp_n = (
thread_id // thread_id % WARP_SIZE,
WARP_SIZE) % block_row_warps, (thread_id // (thread_id // WARP_SIZE) % block_row_warps,
(WARP_SIZE * block_row_warps)) % block_col_warps, (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
...@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter: ...@@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k), l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
warp_m * warp_row_tiles + i * micro_size_x) A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
else: else:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
rk * chunk + ki * (k_pack * micro_size_k)) A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
...@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter: ...@@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
B_base1 + r + col]
else: else:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
...@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter: ...@@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mfma(self, def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter: ...@@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter:
for local_id in T.vectorized(local_size_out): for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id)) row, col = T.meta_var(mfma_store_index_map(tx, local_id))
if C_buf_dims == 2: if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row, C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
(warp_n * warp_cols + j) * N_DIM + i * (warp_cols * local_size_out) + j * local_size_out + local_id
col] = C_local_buf[i * (warp_cols * local_size_out) + ]
j * local_size_out + local_id]
else: else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
col] = C_local_buf[i * warp_cols * local_size_out + i * warp_cols * local_size_out + j * local_size_out + local_id
j * local_size_out + local_id] ]
@T.macro @T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter: ...@@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter:
for i, j in T.grid(warp_rows, warp_cols): for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.vectorized(local_size_out): for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id)) row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, C_buf[
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
local_id]
return (
return _warp_stmatrix_global(C_local_buf, C_buf, _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
thread_binding) if is_global else _warp_stmatrix_shared( if is_global
C_local_buf, C_buf, thread_binding) else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mfma_load_layout(self,
local_buf: Buffer, def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
matrix: Literal["A", "B"] = "A") -> T.Fragment:
""" """
Create a layout function for storing MFMA results into a fragment buffer. Create a layout function for storing MFMA results into a fragment buffer.
...@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter: ...@@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer. If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A" matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B" matrix_is_b: bool = matrix == "B"
...@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter: ...@@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter:
transform_func: Callable = None transform_func: Callable = None
if matrix_is_a: if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
elif matrix_is_b: elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
j, i)
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
...@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter: ...@@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter:
return local_id return local_id
base_fragment = T.Fragment( base_fragment = T.Fragment(
[micro_size_s, micro_size_r * [micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
forward_thread_fn=forward_thread, forward_thread_fn=forward_thread,
forward_index_fn=forward_index, forward_index_fn=forward_index,
) )
...@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter: ...@@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order: if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r], warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1], block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
else: else:
warp_fragment = base_fragment.repeat([warp_r, warp_s], warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s], block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
...@@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter: ...@@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter:
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = "float16",
...@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): ...@@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk * (chunk // micro_size_k) + ki, rk * (chunk // micro_size_k) + ki,
warp_m * warp_rows + i, warp_m * warp_rows + i,
) )
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
col]
else: else:
print(self.a_preshuffle) print(self.a_preshuffle)
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki) l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
col]
return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, return (
rk) if is_global else _warp_ldmatrix_a_shared( _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
A_local_buf, A_buf, ki, thread_binding, rk) if is_global
else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
)
def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
warp_cols = self.warp_cols warp_cols = self.warp_cols
...@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): ...@@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
warp_n * warp_cols + j, warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki, rk * (chunk // micro_size_k) + ki,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
col]
else: else:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b): for local_id in T.vectorized(k_pack * local_size_b):
...@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): ...@@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
rk * (chunk // micro_size_k) + ki, rk * (chunk // micro_size_k) + ki,
warp_n * warp_cols + j, warp_n * warp_cols + j,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
col]
return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, return (
rk) if is_global else _warp_ldmatrix_b_shared( _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
B_local_buf, B_buf, ki, thread_binding, rk) if is_global
else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
)
...@@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): ...@@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id): def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
""" """
groupID = %laneid >> 2 groupID = %laneid >> 2
threadID_in_group = %laneid % 4 threadID_in_group = %laneid % 4
row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
groupID + 8 Otherwise groupID + 8 Otherwise
col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
""" """
row = (thread_id // 4) + 8 * (local_id % 4 // 2) row = (thread_id // 4) + 8 * (local_id % 4 // 2)
col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4) col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
...@@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): ...@@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
""" """
groupID = %laneid >> 2 groupID = %laneid >> 2
threadID_in_group = %laneid % 4 threadID_in_group = %laneid % 4
row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
(threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
col = groupID col = groupID
""" """
col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8 col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8
row = (thread_id // 4) + 8 * (local_id // 4) row = (thread_id // 4) + 8 * (local_id // 4)
......
...@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter: ...@@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
from .utils import mma_store_index_map, mma_store_index_map_fp64 from .utils import mma_store_index_map, mma_store_index_map_fp64
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
if DataType(self.accum_dtype).bits == 64: if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32") index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
...@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter: ...@@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter:
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map return inverse_index_map
def extract_thread_binding( def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
""" """
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter: ...@@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter:
) )
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads # Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.a_dtype).bits == 64: if DataType(self.a_dtype).bits == 64:
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
...@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter: ...@@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
# Assign A_shared_buf_elem # Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_buf[A_base0 + wk, A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk]
A_base1 + wi] if a_transposed else A_buf[A_base0 + wi,
A_base1 + wk]
if ldmatrix_available: if ldmatrix_available:
T.ptx_ldmatrix( T.ptx_ldmatrix(
...@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter: ...@@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter:
for j in T.serial(local_size_a): for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
if a_transposed: if a_transposed:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
A_base1 + wi + mi]
else: else:
A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
A_base1 + wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self, def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
# Fast path for fp64: no ldmatrix support, do direct per-lane loads # Fast path for fp64: no ldmatrix support, do direct per-lane loads
if DataType(self.b_dtype).bits == 64: if DataType(self.b_dtype).bits == 64:
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
...@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter: ...@@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter:
B_base0 = B_region.region[-2].min B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min B_base1 = B_region.region[-1].min
B_stride_last = B_buf.shape[-1] B_stride_last = B_buf.shape[-1]
replicate_b = (self.n_dim == 16) replicate_b = self.n_dim == 16
# ldmatrix cannot be used for int8 + trans case. # ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter: ...@@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter:
) )
if ldmatrix_available: if ldmatrix_available:
B_shared_buf_elem = B_buf[B_base0 + wi, B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi]
B_base1 + wk] if b_transposed else B_buf[B_base0 + wk,
B_base1 + wi]
T.ptx_ldmatrix( T.ptx_ldmatrix(
b_dtype, b_dtype,
...@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter: ...@@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter:
for j in T.serial(local_size_b): for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
if b_transposed: if b_transposed:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
B_base1 + wk + mk]
else: else:
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma(self, def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter: ...@@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter:
accum_dtype = self.accum_dtype accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16) replicate_b = self.n_dim == 16
a_is_fragment = is_fragment(A_local_buf) a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf) b_is_fragment = is_fragment(B_local_buf)
...@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter: ...@@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter:
B_local_buf.data, B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2, b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data, C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
lift(local_size_out) // 2,
T.bool(False), # saturate T.bool(False), # saturate
) )
...@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter: ...@@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter:
local_id = local_id_o * 2 + local_id_i local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id)) row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2: if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row, C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[
(warp_n * warp_cols + j) * n_dim + i * (warp_cols * local_size_out) + j * local_size_out + local_id
col] = C_local_buf[i * (warp_cols * local_size_out) + ]
j * local_size_out + local_id]
else: else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
col] = C_local_buf[i * (warp_cols * local_size_out) + i * (warp_cols * local_size_out) + j * local_size_out + local_id
j * local_size_out + local_id] ]
@T.macro @T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter: ...@@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter:
C_buf[ C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) return (
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mma_load_layout(self, def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter: ...@@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer. If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A" matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B" matrix_is_b: bool = matrix == "B"
...@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter: ...@@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout # so the b matrix expected a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
if matrix_is_a: if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
elif matrix_is_b: elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
j, i)
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
...@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter: ...@@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order: if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r], warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1], block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
else: else:
warp_fragment = base_fragment.repeat([warp_r, warp_s], warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s], block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
...@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter: ...@@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter:
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
shape = local_buf.shape shape = local_buf.shape
assert is_fragment( assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}"
inverse_mma_store_layout = self.get_store_index_map(inverse=True) inverse_mma_store_layout = self.get_store_index_map(inverse=True)
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
...@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
".b16", ".b16",
A_local_buf.data, A_local_buf.data,
i * local_size_a, i * local_size_a,
T.address_of(A_shared_buf[ T.address_of(
warp_m * warp_row_tiles + i * micro_size_x, A_shared_buf[
rk * chunk + ki * micro_size_k, warp_m * warp_row_tiles + i * micro_size_x,
]), rk * chunk + ki * micro_size_k,
]
),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
) )
elif transform_kind_a == TransformKind.InterWarpTransform: elif transform_kind_a == TransformKind.InterWarpTransform:
...@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_m * warp_rows + j, warp_m * warp_rows + j,
rk * (chunk // micro_size_k) + ki, rk * (chunk // micro_size_k) + ki,
) )
rii, rjj = (tx * local_size_a + rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k)
local_id) // micro_size_k, (tx * local_size_a + local_id) % ( A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj]
micro_size_k)
A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj])
else: else:
raise ValueError("Unsupported TransformKind for Input A") raise ValueError("Unsupported TransformKind for Input A")
...@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
warp_n * warp_cols + j, warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki, rk * (chunk // micro_size_k) + ki,
) )
rii, rjj = (tx * local_size_dequantize + rii, rjj = (
local_id) // (micro_size_k // num_elems_per_byte), ( (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte),
tx * local_size_dequantize + local_id) % ( (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte),
micro_size_k // num_elems_per_byte) )
B_local_buf[j * local_size_dequantize + local_id] = ( B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj]
B_shared_buf[ri, rj, rii, rjj])
else: else:
raise ValueError("Unsupported TransformKind for Input B") raise ValueError("Unsupported TransformKind for Input B")
...@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
def mma(self, A_local_buf, B_local_buf, C_local_buf): def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
...@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): ...@@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
def mma(self, A_local_buf, B_local_buf, C_local_buf): def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
......
...@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep): ...@@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id): def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + ( row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
return row, col return row, col
...@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id): ...@@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id): def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4)) row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4))
col = local_id col = local_id
return row, col return row, col
......
...@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter: ...@@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func( index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32 mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, index_dtype="int32",
index_dtype="int32") )
if not inverse: if not inverse:
return index_map return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map return inverse_index_map
def extract_thread_binding( def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
""" """
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter: ...@@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter:
) )
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
A_local_buf: Buffer,
A_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter: ...@@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter:
return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)
def ldmatrix_b(self, def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0):
B_local_buf: Buffer,
B_shared_buf: Buffer | BufferRegion,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter: ...@@ -274,20 +263,14 @@ class TensorCoreIntrinEmitter:
for j in T.vectorized(local_size_b): for j in T.vectorized(local_size_b):
if b_transposed: if b_transposed:
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
B_base1 + wk + mk]
else: else:
mk, mi = mma_load_layout(tx, j) mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
B_base1 + wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk)
def mma(self, def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter: ...@@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter:
return _warp_mma(A_local_buf, B_local_buf, C_local_buf) return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
def make_mma_load_layout(self, def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter: ...@@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer. If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A" matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B" matrix_is_b: bool = matrix == "B"
...@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter: ...@@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout # so the b matrix expected a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
if matrix_is_a: if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
elif matrix_is_b: elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j)
i, j)
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
...@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter: ...@@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter:
return lane_id, local_id return lane_id, local_id
base_fragment = T.Fragment( base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2
forward_fn=forward, )
replicate=2)
warp_rows, warp_cols = self.warp_rows, self.warp_cols warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter: ...@@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order: if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r], warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1], block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
else: else:
warp_fragment = base_fragment.repeat([warp_r, warp_s], warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s], block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
......
...@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int: ...@@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int:
return (thread_id // 4) * 2 + (thread_id % 4) % 2 return (thread_id // 4) * 2 + (thread_id % 4) % 2
def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id) logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 4 + local_id * 8 row = logical_id // 4 + local_id * 8
col = logical_id % 4 col = logical_id % 4
return row, col return row, col
def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id) logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 2 + local_id * 8 row = logical_id // 2 + local_id * 8
col = logical_id % 2 col = logical_id % 2
return row, col return row, col
def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]: return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit
return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]: return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit
return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def get_logical_id_8bit(thread_id: int) -> int: def get_logical_id_8bit(thread_id: int) -> int:
return thread_id return thread_id
def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id) logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8 row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 4 + local_id col = (logical_id % 4) // 2 * 4 + local_id
return row, col return row, col
def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id) logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8 row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 2 + local_id col = (logical_id % 4) // 2 * 2 + local_id
return row, col return row, col
def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]:
local_id: int) -> tuple[int, int]:
# local_id is always 0 # local_id is always 0
logical_id = get_logical_id_8bit(thread_id) logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 4 + (logical_id % 2) * 8 row = logical_id // 4 + (logical_id % 2) * 8
......
...@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter:
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
self.local_size_e = ( self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size
...@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter:
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map return inverse_index_map
def extract_thread_binding( def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
""" """
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
# Assign A_shared_buf_elem # Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
if ldmatrix_available: if ldmatrix_available:
...@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter: ...@@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter:
else: else:
for j in T.serial(local_size_a): for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + A_local_buf[i * local_size_a + j] = (
j] = A_shared_buf[wk + mk, wi + A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
mi] if a_transposed else A_shared_buf[wi + mi, )
wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
...@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter: ...@@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter:
tx, _, warp_m = self.extract_thread_binding(thread_binding) tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
# Assign E_shared_buf_elem # Assign E_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor
rk * warp_k + ki * micro_size_k) // self.e_factor
for j in T.serial(local_size_e): for j in T.serial(local_size_e):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
E_local_buf[i * local_size_e + E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]
j] = E_shared_buf[wk + mk,
wi + mi] if trans else E_shared_buf[wi + mi,
wk + mk]
return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
...@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter:
b_dtype = self.b_dtype b_dtype = self.b_dtype
b_transposed = self.b_transposed b_transposed = self.b_transposed
thread_binding = self.get_thread_binding() thread_binding = self.get_thread_binding()
replicate_b = (self.n_dim == 16) replicate_b = self.n_dim == 16
# ldmatrix cannot be used for int8 + trans case. # ldmatrix cannot be used for int8 + trans case.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
...@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter:
) )
if ldmatrix_available: if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
wi]
if replicate_b: if replicate_b:
T.ptx_ldmatrix( T.ptx_ldmatrix(
...@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf.data, B_local_buf.data,
i * local_size_b + lift(local_size_b) // 2, i * local_size_b + lift(local_size_b) // 2,
T.address_of(B_shared_buf_elem), T.address_of(B_shared_buf_elem),
get_ldmatrix_offset_b("B", tx, get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed),
lift(local_size_b) // 2, stride, b_dtype,
b_transposed),
) )
else: else:
T.ptx_ldmatrix( T.ptx_ldmatrix(
...@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter: ...@@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter:
# must be transposed. # must be transposed.
for j in T.serial(local_size_b): for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j) mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + B_local_buf[i * local_size_b + j] = (
j] = B_shared_buf[wi + mi, wk + B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi]
mk] if b_transposed else B_shared_buf[wk + mk, )
wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma_sp(self, def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0):
A_local_buf: Buffer,
E_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter:
accum_dtype = self.accum_dtype accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16) replicate_b = self.n_dim == 16
a_is_fragment = is_fragment(A_local_buf) a_is_fragment = is_fragment(A_local_buf)
e_is_fragment = is_fragment(E_local_buf) e_is_fragment = is_fragment(E_local_buf)
...@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter:
B_local_buf.data, B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2, b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data, C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
lift(local_size_out) // 2,
E_local_buf.data, # metadata E_local_buf.data, # metadata
e_local_stride + i * local_size_e, # metadata offset e_local_stride + i * local_size_e, # metadata offset
self.SPARSE_SELECTOR, # sparse_selector self.SPARSE_SELECTOR, # sparse_selector
...@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter: ...@@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter:
local_id = local_id_o * 2 + local_id_i local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id)) row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2: if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row, C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[
(warp_n * warp_cols + j) * n_dim + i * (warp_cols * local_size_out) + j * local_size_out + local_id
col] = C_local_buf[i * (warp_cols * local_size_out) + ]
j * local_size_out + local_id]
else: else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[
col] = C_local_buf[i * (warp_cols * local_size_out) + i * (warp_cols * local_size_out) + j * local_size_out + local_id
j * local_size_out + local_id] ]
@T.macro @T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
...@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter: ...@@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter:
C_buf[ C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) return (
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
)
def make_mma_load_layout(self, def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment:
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter:
If `local_buf` is not detected to be a fragment buffer. If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A" matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B" matrix_is_b: bool = matrix == "B"
...@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter: ...@@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter:
# so the b matrix expected a transposed basic layout # so the b matrix expected a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
if matrix_is_a: if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
elif matrix_is_b: elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
j, i)
else: else:
raise ValueError(f"Unsupported matrix {matrix}") raise ValueError(f"Unsupported matrix {matrix}")
...@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter: ...@@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter:
return local_id return local_id
base_fragment = T.Fragment( base_fragment = T.Fragment(
[micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r]
if is_sr_axis_order
else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s], else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
forward_thread_fn=forward_thread, forward_thread_fn=forward_thread,
forward_index_fn=forward_index, forward_index_fn=forward_index,
...@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter: ...@@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter:
replicate = block_col_warps if matrix_is_a else block_row_warps replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order: if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r], warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1], block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
else: else:
warp_fragment = base_fragment.repeat([warp_r, warp_s], warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a: if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s], block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate)
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b: elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True)
repeat_on_thread=True,
lower_dim_first=True)
else: else:
raise ValueError(f"Unsupported matrix type {matrix}") raise ValueError(f"Unsupported matrix type {matrix}")
......
...@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first: bool = False, is_m_first: bool = False,
thread_var: Var | None = None, thread_var: Var | None = None,
): ):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, super().__init__(
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, a_dtype,
num_elems_per_byte, is_m_first, thread_var) b_dtype,
accum_dtype,
a_transposed,
b_transposed,
block_row_warps,
block_col_warps,
warp_row_tiles,
warp_col_tiles,
chunk,
reduce_k,
num_elems_per_byte,
is_m_first,
thread_var,
)
def _assign_a_shared_layout(self, layout: Layout):
    """Record the shared-memory layout selected for operand A.

    The layout is stored verbatim on ``self.a_shared_layout``; it is
    presumably consulted later when determining the swizzle mode for the
    A operand — confirm against the emitter's swizzle-detection path.
    """
    self.a_shared_layout = layout
...@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else: else:
raise ValueError(f"Unsupported swizzle mode: {layout}") raise ValueError(f"Unsupported swizzle mode: {layout}")
def tcgen05mma(self, def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False):
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
mbar,
clear_accum: PrimExpr = False):
if is_tensor_memory(A_buf): if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum) return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
...@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bits = DataType(self.a_dtype).bits elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8 elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_dtype_in_bits = DataType(accum_dtype).bits accum_dtype_in_bits = DataType(accum_dtype).bits
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 5: if len(meta) != 5:
raise ValueError( raise ValueError(
f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
)
atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
# by default, we utilize non-swizzle layout offset # by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
elems_in_bytes) a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
if not a_swizzle_mode.is_none(): if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
...@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else: else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
elems_in_bytes) b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none(): if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
for ki in T.unroll(0, (k_dim // micro_size_k)): for ki in T.unroll(0, (k_dim // micro_size_k)):
scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
A_elem_offset = ( A_elem_offset = (
ki % ak_atom_size (ki % ak_atom_size) * micro_size_k
) * micro_size_k + i * atom_m * a_swizzle_atom_elems + ( + i * atom_m * a_swizzle_atom_elems
ki // ak_atom_size + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k if a_is_k_major
else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
)
B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( B_elem_offset = (
ki % bk_atom_size (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else ( + (ki % bk_atom_size) * micro_size_k
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * + j * atom_n * b_swizzle_atom_elems
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) if b_is_k_major
else (
ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
A_byte_offset = A_elem_offset * elems_in_bytes A_byte_offset = A_elem_offset * elems_in_bytes
B_byte_offset = B_elem_offset * elems_in_bytes B_byte_offset = B_elem_offset * elems_in_bytes
C_offset = (i * n_dim + j * tmem_col_step C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
) * accum_dtype_in_bits // 32 # 32 bits per tmem bank
T.ptx_tcgen05_mma_ss( T.ptx_tcgen05_mma_ss(
a_dtype_abbrv, a_dtype_abbrv,
...@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
""" """
assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)" assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
if len(tmem_buf.shape) != 2: if len(tmem_buf.shape) != 2:
raise ValueError( raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
m = int(tmem_buf.shape[0]) m = int(tmem_buf.shape[0])
n = int(tmem_buf.shape[1]) n = int(tmem_buf.shape[1])
...@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
meta = self.get_tcgen5_mma_meta(m, n, k) meta = self.get_tcgen5_mma_meta(m, n, k)
if len(meta) != 5: if len(meta) != 5:
raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " raise ValueError(
f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
)
atom_m, atom_n, _, _, _ = (int(x) for x in meta) atom_m, atom_n, _, _, _ = (int(x) for x in meta)
if m % atom_m != 0 or n % atom_n != 0: if m % atom_m != 0 or n % atom_n != 0:
raise ValueError( raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})")
f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
)
def forward(i: PrimExpr, j: PrimExpr): def forward(i: PrimExpr, j: PrimExpr):
atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m) atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
...@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return Layout([m, n], forward) return Layout([m, n], forward)
def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
    """Query the FFI for TCGEN5 MMA atom metadata of an (m, n, k) problem.

    Dimensions are coerced to plain ``int`` and the emitter's A/accumulator
    dtypes are wrapped as ``DataType`` before the call. The FFI returns the
    atom configuration tuple (callers unpack five values when supported).
    """
    a_dt = DataType(self.a_dtype)
    acc_dt = DataType(self.accum_dtype)
    return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), a_dt, acc_dt)
def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, def get_tcgen5_instr_desc(
b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr: self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int
) -> PrimExpr:
desc = _ffi_api.get_tcgen5_instr_desc( desc = _ffi_api.get_tcgen5_instr_desc(
atom_m, atom_m,
atom_n, atom_n,
......
...@@ -10,7 +10,7 @@ from .mma_layout import ( ...@@ -10,7 +10,7 @@ from .mma_layout import (
mma_store_32x8_to_shared_16x16_layout, mma_store_32x8_to_shared_16x16_layout,
mma_store_32x2_to_shared_8x8_layout_fp64, mma_store_32x2_to_shared_8x8_layout_fp64,
) )
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m
from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401
......
...@@ -15,9 +15,11 @@ from tilelang.layout import ( ...@@ -15,9 +15,11 @@ from tilelang.layout import (
make_linear_layout, make_linear_layout,
) )
from tvm.runtime import convert from tvm.runtime import convert
from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a, from tilelang.intrinsics.mma_layout import (
shared_16x16_to_mma_32x8_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a) shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a,
)
lift = convert lift = convert
...@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
is_m_first: bool | None = False, is_m_first: bool | None = False,
thread_var: Var | None = None, thread_var: Var | None = None,
): ):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, super().__init__(
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, a_dtype,
num_elems_per_byte, is_m_first, thread_var) b_dtype,
accum_dtype,
a_transposed,
b_transposed,
block_row_warps,
block_col_warps,
warp_row_tiles,
warp_col_tiles,
chunk,
reduce_k,
num_elems_per_byte,
is_m_first,
thread_var,
)
self._initialize_wgmma_prefix(self.n_dim) self._initialize_wgmma_prefix(self.n_dim)
def _assign_a_shared_layout(self, layout: Layout):
    """Store the caller-provided shared-memory layout for operand A.

    Kept as-is on ``self.a_shared_layout``; downstream swizzle-mode
    selection is expected to read it — verify against the consumer.
    """
    self.a_shared_layout = layout
def _initialize_wgmma_prefix(self, n_dim: int = 16): def _initialize_wgmma_prefix(self, n_dim: int = 16):
inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
assert inst_n % 8 == 0, ( assert inst_n % 8 == 0, (
f"inst_n must be a multiple of 8, got {inst_n} " f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") )
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert 8 <= inst_n <= 256, ( assert 8 <= inst_n <= 256, (
f"inst_n must be within [8, 256], got {inst_n} " f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})"
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") )
# 256 bits per instruction # 256 bits per instruction
inst_k = 256 // DataType(self.a_dtype).bits inst_k = 256 // DataType(self.a_dtype).bits
self.wgmma_inst_m = inst_m self.wgmma_inst_m = inst_m
...@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
else: else:
raise ValueError(f"Unsupported swizzle mode: {layout}") raise ValueError(f"Unsupported swizzle mode: {layout}")
def wgmma(self, def wgmma(
A_region: BufferRegion, self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0
B_region: BufferRegion, ):
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
if is_fragment(A_region): if is_fragment(A_region):
return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait) return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait)
...@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
elems_in_bytes = elems_in_bits // 8 elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
# by default, we utilize non-swizzle layout offset # by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
elems_in_bytes) a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
if not a_swizzle_mode.is_none(): if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
...@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
if a_m_axis_atoms <= 1: if a_m_axis_atoms <= 1:
a_leading_byte_offset = 0 a_leading_byte_offset = 0
else: else:
a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * ( a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
if a_m_axis_atoms <= 1: if a_m_axis_atoms <= 1:
a_stride_byte_offset = 8 * elems_in_bytes * m_dim a_stride_byte_offset = 8 * elems_in_bytes * m_dim
else: else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
elems_in_bytes) b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none(): if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
desc_a = T.alloc_wgmma_desc() desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc() desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
int(a_leading_byte_offset >> 4), T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive() T.warpgroup_arrive()
...@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -291,21 +291,41 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
warp_i = (warp_m // 4) * num_inst_m + i warp_i = (warp_m // 4) * num_inst_m + i
warp_j = warp_n * num_inst_n + j warp_j = warp_n * num_inst_n + j
A_offset = ( A_offset = (
ki % ak_atom_size (ki % ak_atom_size) * micro_size_k
) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + ( + warp_i * 64 * a_swizzle_atom_elems
ki // ak_atom_size + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k if a_is_k_major
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
ki % bk_atom_size )
) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ( B_offset = (
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + (ki % bk_atom_size) * micro_size_k
+ warp_j * wgmma_inst_n * b_swizzle_atom_elems
if b_is_k_major
else (
ki * b_swizzle_atom_elems * micro_size_k
+ warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, T.ptx_wgmma_ss(
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, accum_dtype,
(A_offset * elems_in_bytes) >> 4, desc_b.data, wgmma_prefix,
(B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset, a_is_k_major,
scale_out, scale_in_a, scale_in_b) b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
desc_a.data,
(A_offset * elems_in_bytes) >> 4,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
T.warpgroup_commit_batch() T.warpgroup_commit_batch()
if wg_wait >= 0: if wg_wait >= 0:
...@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return _warp_mma(A_ptr, B_ptr, C_buf) return _warp_mma(A_ptr, B_ptr, C_buf)
def wgmma_rs(self, def wgmma_rs(
A_region: BufferRegion, self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0
B_region: BufferRegion, ):
C_region: BufferRegion,
clear_accum: PrimExpr = False,
wg_wait: int = 0):
local_size_a = self.local_size_a local_size_a = self.local_size_a
local_size_out = self.local_size_out local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv a_dtype_abbrv = self.a_dtype_abbrv
...@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
b_is_k_major = self.b_transposed b_is_k_major = self.b_transposed
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes))
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none(): if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1 # swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
...@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -390,9 +403,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_b = T.alloc_wgmma_desc() desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(A_buf, num_regs=a_regs) T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
T.warpgroup_arrive() T.warpgroup_arrive()
...@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
A_offset = ki * warp_rows * local_size_a + i * local_size_a A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = ( B_offset = (
ki // bk_atom_size (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + ( + warp_j * wgmma_inst_n * b_swizzle_atom_elems
ki % bk_atom_size) * micro_size_k if b_is_k_major else ( + (ki % bk_atom_size) * micro_size_k
ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * if b_is_k_major
(k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) else (
ki * b_swizzle_atom_elems * micro_size_k
+ warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
)
)
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit
T.ptx_wgmma_rs( T.ptx_wgmma_rs(
accum_dtype, accum_dtype,
...@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
If `local_buf` is not detected to be a fragment buffer. If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.utils import is_fragment from tilelang.utils import is_fragment
assert matrix in ["A"], "matrix should be A for WGMMA" assert matrix in ["A"], "matrix should be A for WGMMA"
dtype = self.a_dtype dtype = self.a_dtype
dtype_bits = DataType(dtype).bits dtype_bits = DataType(dtype).bits
...@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# the layout of mma.sync is row.col. # the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout # so the b matrix expected a transposed basic layout
transform_func: Callable = None transform_func: Callable = None
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
j, i)
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
...@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
replicate = block_col_warps replicate = block_col_warps
if is_sr_axis_order: if is_sr_axis_order:
warp_fragment = base_fragment.repeat([block_s, 1], warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
repeat_on_thread=True, block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False)
lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
else: else:
# rs condition, transposed_a matrix # rs condition, transposed_a matrix
warp_fragment = base_fragment.repeat([1, block_s], warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate)
repeat_on_thread=True, block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True)
lower_dim_first=False).replicate(replicate)
block_fragment = warp_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
return block_fragment return block_fragment
......
...@@ -7,23 +7,19 @@ from tilelang import _ffi_api ...@@ -7,23 +7,19 @@ from tilelang import _ffi_api
@tvm_ffi.register_object("tl.Fill") @tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable): class Fill(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.AtomicAdd") @tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable): class AtomicAdd(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.Copy") @tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable): class Copy(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.Conv2DIm2Col") @tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable): class Conv2DIm2ColOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.GemmWarpPolicy") @tvm_ffi.register_object("tl.GemmWarpPolicy")
...@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable): ...@@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable):
m_warp: int m_warp: int
n_warp: int n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool):
is_wgmma: bool): _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma)
_ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma)
return self.m_warp, self.n_warp return self.m_warp, self.n_warp
...@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable): ...@@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable):
m_warp: int m_warp: int
n_warp: int n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int):
is_wgmma: bool, bits: int): _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits)
_ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma, bits)
return self.m_warp, self.n_warp return self.m_warp, self.n_warp
@tvm_ffi.register_object("tl.Gemm") @tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable): class Gemm(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.GemmSP") @tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable): class GemmSP(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.FinalizeReducerOp") @tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable): class FinalizeReducerOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.ParallelOp") @tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable): class ParallelOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.ReduceOp") @tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable): class ReduceOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.CumSumOp") @tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable): class CumSumOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.RegionOp") @tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable): class RegionOp(Node, Scriptable): ...
...
@tvm_ffi.register_object("tl.ReduceType") @tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable): class ReduceType(Node, Scriptable): ...
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment