Commit cd191889 authored by Lei Wang, committed by GitHub

[Carver] Introduce a tile-structure based cost model for auto tuning (#70)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University

* enhance acknowledgement

* Replace tutorial on memory layout optimization with new tutorial on writing high-performance kernels with thread primitives

* Update subproject commit for TVM dependency

* Update subproject commit for TVM dependency

* Add int4_t type and functions for packing char values in CUDA common header

* Add plot_layout example and implement GetForwardVars method in layout classes

* Refactor code for improved readability by adjusting line breaks and formatting in layout and test files

* Fix formatting by removing unnecessary line break in layout.h

* Refactor make_int4 function for improved readability by adjusting parameter formatting

* Add legend to plot_layout for improved clarity of thread and local IDs

* Remove unnecessary dependencies from requirements files for cleaner setup

* Remove flash_mha.py and add .gitkeep to deepseek_mla directory

* Add build requirements and update installation scripts for improved setup

* Introduce carver

* Refactor imports and improve code formatting for consistency

* Add unit tests for carver recommendation hints

* lint fix

* Enhance ElementwiseTemplate and BaseTemplate with detailed docstrings for improved code documentation and clarity

* Refactor import statements and clean up whitespace in template files for improved readability

* Add README.md for Carver framework with usage examples and architecture support
parent 2411fa28
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .tir import get_analyzer_by_tir # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import OrderedDict
from typing import Dict, List
from tvm import arith
class Statement:
def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
range_map: OrderedDict):
self.output = output
self.dependent_region = dependent_region
self.var_map = var_map
self.range_map = range_map
def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
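# Illustrative sketch (not in the original file): _merge_two_bounds widens two
# constant-integer bounds into one interval covering both accesses. The concrete
# values below are made up.
def _example_merge_two_bounds():
    x = arith.ConstIntBound(0, 15)  # e.g. the region touched by A[i]
    y = arith.ConstIntBound(8, 31)  # e.g. the region touched by A[i + 8]
    merged = _merge_two_bounds(x, y)
    assert merged.min_value == 0 and merged.max_value == 31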
class InputShapeInference:
def __init__(self, deps: List[Statement]):
self.deps = deps
def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]):
shape = shape.copy()
ana = arith.Analyzer()
for dep in reversed(self.deps):
for var, bound in zip(dep.var_map.values(), shape[dep.output]):
ana.update(var, bound)
for var, bound in dep.range_map.items():
if var.name in rstep:
bound = arith.ConstIntBound(0, min(bound.max_value, rstep[var.name] - 1))
ana.update(var, bound)
for name, regions in dep.dependent_region.items():
for region in regions:
bounds = [ana.const_int_bound(index) for index in region]
if name in shape: # simply merge two bounds
bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)]
shape[name] = bounds
for name, bounds in shape.items():
shape[name] = [c.max_value - c.min_value + 1 for c in bounds]
return shape
def infer(self, shape, rstep: Dict[str, int] = None):
if rstep is None:
rstep = {}
if isinstance(shape, (list, tuple)):
shape = {"output0": [arith.ConstIntBound(0, val - 1) for val in shape]}
shape = self._infer(shape, rstep)
return shape
def get_input_exprs(self, output_exprs):
result = output_exprs.copy()
ana = arith.Analyzer()
for dep in reversed(self.deps):
for var, expr in zip(dep.var_map.values(), result[dep.output]):
ana.bind(var, expr)
for var in dep.range_map:
ana.bind(var, 0)
for name, regions in dep.dependent_region.items():
if name in result:
continue
region = regions[0]
input_expr = [ana.simplify(index) for index in region]
result[name] = input_expr
return result
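# Illustrative sketch (not in the original file): infer the input region that a
# stencil B[i] = A[i] + A[i + 1] touches when B spans [0, 9]. The names, sizes,
# and the locally imported tir module are illustrative assumptions.
def _example_infer():
    from tvm import tir
    i = tir.Var("i", "int32")
    stmt = Statement(
        output="B",
        dependent_region={"A": [[i], [i + 1]]},
        var_map=OrderedDict([("i", i)]),
        range_map=OrderedDict(),
    )
    shapes = InputShapeInference([stmt]).infer({"B": [arith.ConstIntBound(0, 9)]})
    assert shapes == {"B": [10], "A": [11]}  # A needs one extra element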
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Tuple, Set, Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
class Statement:
def __init__(self, block_analyzer, block: BlockRV):
self.block_analyzer = block_analyzer
self.block = block
        # assume a TIR block has only one output buffer
self.dep_name = block_analyzer.get_output_buffers(block)[0].name
self.dependent_region = _extract_dependent_region(block_analyzer, block)
self.reverse_bound_inference = {}
def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]):
if len(self.block_analyzer.get_reduce_axis(self.block)) > 0:
return None
if len(self.dependent_region[input_name]) != 1:
return None
indices = self.dependent_region[input_name][0]
iter_map_range = {
_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)
}
iter_map_result = arith.detect_iter_map(
indices,
iter_map_range,
check_level=arith.iter_affine_map.IterMapLevel.Surjective,
simplify_trivial_iterators=False,
)
if len(iter_map_result.errors) > 0:
return None
results = arith.iter_affine_map.inverse_affine_iter_map(iter_map_result.indices, input_iter)
output_indices = []
for _iter in self.block_analyzer.get_spatial_axis(self.block):
if _iter.var in results:
output_indices.append(results[_iter.var])
else:
                # non-bijective mapping case
output_indices.append(tir.Var("undefined", dtype="int32") % int(_iter.dom.extent))
return output_indices
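# Illustrative sketch (not in the original file) of the inverse-mapping machinery
# make_reverse builds on: detect an affine map from block iterators to buffer
# indices, then invert it. The iterators i, j and the 16x32 domain are made up.
def _example_inverse_iter_map():
    from tvm.ir import Range
    i = tir.Var("i", "int32")
    j = tir.Var("j", "int32")
    dom = {i: Range.from_min_extent(0, 16), j: Range.from_min_extent(0, 32)}
    # An access A[j, i] is a bijective, affine rearrangement of the iterators.
    res = arith.detect_iter_map([j, i], dom)
    assert len(res.errors) == 0
    oi, oj = tir.Var("oi", "int32"), tir.Var("oj", "int32")
    # Maps each block iterator back to an expression of the output-side indices:
    # here inverse[i] == oj and inverse[j] == oi.
    inverse = arith.iter_affine_map.inverse_affine_iter_map(res.indices, [oi, oj])
    return inverse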
def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
class TensorDepNode(object):
"""
For tensor dependency analysis.
"""
def __init__(self, name):
self.name = name
self._next = []
self._prev = []
def add_next(self, node):
self._next.append(node)
self.deduplicate(self._next)
def add_prev(self, node):
self._prev.append(node)
self.deduplicate(self._prev)
def deduplicate(self, lst):
seen = set()
lst[:] = [n for n in lst if not (n in seen or seen.add(n))]
def __str__(self):
return self.name
def __repr__(self):
return self.name
class DependencyAnalysis(object):
def __init__(self, deps):
self.deps = deps
        # issue: duplicate names can occur when the graph contains two identical ops.
self.name2dep = self._construct_unique_name2dep(deps)
self.mapping = {} # name -> TensorDepNode
def _construct_unique_name2dep(self, deps):
"""
This is a workaround for the issue that we have two same ops' fuse case.
See https://github.com/apache/tvm/issues/16433
"""
_names: Set = set()
name2dep: Mapping = {}
for dep in deps:
output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0]
base_name = output_buffer.name
if base_name not in _names:
_names.add(base_name)
else:
i = 1
while f"{base_name}_{i}" in _names:
i += 1
base_name = f"{base_name}_{i}"
_names.add(base_name)
name2dep[base_name] = dep
return name2dep
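    # Illustrative note (not in the original file): three blocks that all write a
    # buffer named "T_add" end up keyed as "T_add", "T_add_1", and "T_add_2".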
def get_or_create_node(self, name):
if name not in self.mapping:
self.mapping[name] = TensorDepNode(name)
return self.mapping[name]
def traverse_dependencies(self, compute):
if isinstance(compute, Statement):
node = self.get_or_create_node(
compute.block_analyzer.get_output_buffers(compute.block)[0].name)
# Loop through input tensors
for input_buffer in compute.block_analyzer.get_input_buffers(compute.block):
# Get the input node
input_node = self.traverse_dependencies(input_buffer)
input_node.add_next(node)
node.add_prev(input_node)
elif isinstance(compute, tir.Buffer):
node = self.get_or_create_node(compute.name)
return node
def analyze(self):
# Starting point for traversal
for _, compute in self.name2dep.items():
self.traverse_dependencies(compute)
def print_dependencies(self):
for name, node in self.mapping.items():
print(f"{name} depends on {', '.join([prev.name for prev in node._prev])}")
def find_path_from_source(self, start_name, target_name):
"""
Finds the path (if it exists) from a starting node (source) to a target node.
Returns the path as a list of nodes.
"""
visited = set()
path = []
if self._find_path_recursive(self.mapping[start_name], target_name, visited, path):
return path
return []
def _find_path_recursive(self, current_node, target_name, visited, path):
"""
Recursive helper function for find_path_from_source.
"""
if current_node.name == target_name:
path.append(current_node)
return True
if current_node.name in visited:
return False
visited.add(current_node.name)
path.append(current_node)
for next_node in current_node._next:
if self._find_path_recursive(next_node, target_name, visited, path):
return True
path.pop()
return False
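# Illustrative usage sketch (not in the original file): build a tiny A -> B -> C
# dependency chain by hand and query the DFS path finder. All names are made up.
def _example_find_path():
    a, b, c = TensorDepNode("A"), TensorDepNode("B"), TensorDepNode("C")
    a.add_next(b)
    b.add_prev(a)
    b.add_next(c)
    c.add_prev(b)
    analysis = DependencyAnalysis(deps=[])  # empty deps: wire the graph manually
    analysis.mapping = {"A": a, "B": b, "C": c}
    path = analysis.find_path_from_source("A", "C")
    assert [node.name for node in path] == ["A", "B", "C"]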
class InputShapeInference:
def __init__(self, deps: List[Statement]):
self.deps = deps
self.target_mapping = {}
self.buffer_mapping = {}
self.reduce_axes = []
for dep in self.deps:
for ax in dep.block_analyzer.get_reduce_axis(dep.block):
self.reduce_axes.append(ax)
self.dep_analysis = DependencyAnalysis(self.deps)
self.dep_analysis.analyze()
def construct_dependency_target(self, targets: Tuple[str]):
if targets in self.target_mapping:
return self.target_mapping[targets]
# should be buffer name instead of block name
name2dep = {
dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps
}
mapping = {}
input_vars = []
for target in targets:
vars = [
iter.var
for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)
]
input_vars.append(vars)
mapping[target] = [vars]
ana = arith.Analyzer()
for dep in self.deps:
for name in dep.dependent_region:
if name not in mapping:
continue
dep_name = dep.dep_name
indices = mapping[name][0]
output_indices = dep.make_reverse(name, indices)
if dep_name in targets:
continue
if dep_name not in mapping:
mapping[dep_name] = [output_indices]
elif not region_exist_in_list(output_indices, mapping[dep_name]):
mapping[dep_name].append(output_indices)
for dep in reversed(self.deps):
indices_list = mapping[dep.dep_name]
ax_vars = [iter.var for iter in dep.block_analyzer.get_spatial_axis(dep.block)]
for input_name, regions in dep.dependent_region.items():
if input_name in targets:
continue
if input_name not in mapping:
mapping[input_name] = []
for indices in indices_list:
for region in regions:
vmap = {
k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v)
for k, v in zip(ax_vars, indices)
}
region = [
ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region
]
if not region_exist_in_list(region, mapping[input_name]):
mapping[input_name].append(region)
buffers = []
for dep in self.deps:
for buffer in dep.block_analyzer.get_buffers(dep.block):
buffers.append(buffer)
for buffer in buffers:
self.buffer_mapping[buffer.name] = buffer
self.target_mapping[targets] = input_vars, mapping
return input_vars, mapping
def infer(self,
shape: Dict[str, List[arith.ConstIntBound]],
rstep: Dict[str, int] = None,
targets=None):
if rstep is None:
rstep = {}
compute_targets = tuple(shape.keys())
input_vars, mapping = self.construct_dependency_target(compute_targets)
ana = arith.Analyzer()
results = {}
intermediate_bind = {}
for vars, bounds in zip(input_vars, shape.values()):
for var, bound in zip(vars, bounds):
ana.update(var, bound, True)
for ax in self.reduce_axes:
            # assume dom.min is always 0; the IterInfo could be extended to carry the min value
if ax.var.name in rstep:
bound = arith.ConstIntBound(
int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
else:
bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1))
ana.update(ax.var, bound, True)
for name, regions in mapping.items():
if targets is not None and name not in targets:
continue
            # compute_targets[0:1] == compute_targets holds only when there is exactly
            # one compute target; in that case, try to tighten the bounds through the
            # intermediate nodes on the path to it.
            if compute_targets[0:1] == compute_targets:
(compute_target,) = compute_targets
path = self.dep_analysis.find_path_from_source(name, compute_target)
if len(path) > 2:
intermediate_nodes = path[1:-1]
for node in intermediate_nodes:
iters = mapping[node.name]
if len(iters) != len(regions) or len(iters) != 1:
continue
                        # len(iters) == 1 here, so len(*iters) is the arity of the single
                        # index tuple; bail out if it differs from the region's arity.
                        if len(*iters) != len(*regions):
                            break
regions = iters
intermediate_bind[name] = compute_target
for region in regions:
bound = [ana.const_int_bound(indice) for indice in region]
if name in results: # simply merge two bounds
bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)]
results[name] = bound
else:
for region in regions:
bound = [ana.const_int_bound(indice) for indice in region]
if name in results: # simply merge two bounds
bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)]
results[name] = bound
for name, bounds in results.items():
results[name] = [c.max_value - c.min_value + 1 for c in bounds]
return results, intermediate_bind
def get_input_exprs(self, output_exprs):
input_vars, mapping = self.construct_dependency_target(tuple(output_exprs.keys()))
ana = arith.Analyzer()
for ax in self.reduce_axes:
ana.bind(ax.var, 0)
vmap = {}
for vars, exprs in zip(input_vars, output_exprs.values()):
for var, expr in zip(vars, exprs):
if expr.dtype != var.dtype:
expr = tir.Cast(var.dtype, expr)
vmap[var] = expr
result = {}
for name, regions in mapping.items():
region = regions[0]
result[name] = [
ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region
]
return result
def region_exist_in_list(a, regions) -> bool:
    def expr_is_same(a, b) -> bool:
        if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm):
            return a.value == b.value
        return structural_equal(a, b)
    def region_is_same(a, b) -> bool:
        return all(expr_is_same(indice_a, indice_b) for indice_a, indice_b in zip(a, b))
    return any(region_is_same(a, x) for x in regions)
def walk_indice(expr):
if isinstance(expr, tir.expr.BinaryOpExpr):
a = walk_indice(expr.a)
b = walk_indice(expr.b)
if a is not None and b is not None:
return expr
else:
return None
elif isinstance(expr, (tir.Var, tir.expr.ConstExpr)):
return expr
elif isinstance(expr, tir.ProducerLoad):
return None
elif isinstance(expr, tir.Cast):
a = walk_indice(expr.value)
if a is not None:
return expr
return None
elif isinstance(expr, tir.Call):
return None
else:
raise Exception("Unhandled node type in walk_indice(): %s" % expr)
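# Illustrative sketch (not in the original file): walk_indice keeps index
# expressions built from plain variables, constants, casts, and arithmetic, and
# rejects opaque ones such as calls. The variable and extern name are made up.
def _example_walk_indice():
    i = tir.Var("i", "int32")
    assert walk_indice(i * 4 + 1) is not None  # affine expression: kept as-is
    assert walk_indice(tir.Cast("int64", i)) is not None  # cast of a var: kept
    assert walk_indice(tir.call_extern("int32", "fn", i)) is None  # opaque call: dropped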
def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]:
input_buffers = block_analyzer.get_input_buffers(block)
dependent_region = {buffer.name: [] for buffer in input_buffers}
def fvisit(x):
if not isinstance(x, tir.BufferLoad):
return
if x.buffer.name not in dependent_region:
return
index = []
for indice, shape_limit in zip(x.indices, x.buffer.shape):
expr = walk_indice(indice)
if expr is None:
expr = tir.Var("undefined", dtype="int8") % shape_limit
if isinstance(expr, tir.IntImm) and expr.value == 0:
"""for tensor ir zero dim smplification case.
for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1024), T.int64(1024)):
with T.block("T_dense"):
v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
T.reads(A_reindex[T.int64(0), v0, v2], B_reindex[T.int64(0), v1, v2])
T.writes(T_dense_reindex[T.int64(0), v0, v1])
with T.init():
T_dense_reindex[T.int64(0), v0, v1] = T.float16(0)
T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2]
For example, the T_dense_reindex has three dims, however there're only two spatial loops.
"""
continue
index.append(expr)
if not region_exist_in_list(index, dependent_region[x.buffer.name]):
dependent_region[x.buffer.name].append(index)
stmt = block_analyzer.sch.get(block)
tir.stmt_functor.post_order_visit(stmt, fvisit=fvisit)
return dependent_region
def get_analyzer_by_tir(block_analyzer, args) -> InputShapeInference:
deps = [Statement(block_analyzer, block) for block in args]
return InputShapeInference(deps)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Template for the TileLang Carver."""
from .base import BaseTemplate # noqa: F401
from .matmul import MatmulTemplate # noqa: F401
from .gemv import GEMVTemplate # noqa: F401
from .elementwise import ElementwiseTemplate # noqa: F401
from .general_reduce import GeneralReductionTemplate # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules and classes
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
from ..roller import Hint # Import the Hint class
from typing import List # For type hinting
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
@dataclass
class BaseTemplate(ABC):
"""
Base class template for hardware-aware configurations.
This serves as an abstract base class (ABC) that defines the structure
for subclasses implementing hardware-specific optimizations.
"""
    # The device architecture. Note the default is evaluated once, at class-definition
    # time, by auto_infer_current_arch(); override it per instance via with_arch().
    _arch: TileDevice = field(default=auto_infer_current_arch(), init=False, repr=False)
# The function associated with this template, initially None
_func: PrimFunc = field(default=None, init=False, repr=False)
@abstractmethod
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
"""
Abstract method that must be implemented by subclasses.
It should return a list of hardware-aware configurations (hints)
based on the specified architecture.
Args:
arch (TileDevice, optional): The target architecture. Defaults to None.
topk (int, optional): Number of top configurations to return. Defaults to 10.
Returns:
List[Hint]: A list of recommended hardware-aware configurations.
"""
pass
def with_arch(self, arch: TileDevice) -> "BaseTemplate":
"""
Sets the architecture for this template and returns itself.
Args:
arch (TileDevice): The architecture to set.
Returns:
BaseTemplate: The instance with the updated architecture.
"""
self._arch = arch
return self
def has_arch(self) -> bool:
"""
Checks whether the architecture is set.
Returns:
bool: True if the architecture is set, False otherwise.
"""
return self._arch is not None
def is_volta_arch(self) -> bool:
"""
Checks if the current architecture is a Volta architecture.
Returns:
bool: True if the architecture is Volta, False otherwise.
"""
return is_volta_arch(self._arch) if self._arch is not None else False
def is_ampere_arch(self) -> bool:
"""
Checks if the current architecture is an Ampere architecture.
Returns:
bool: True if the architecture is Ampere, False otherwise.
"""
return is_ampere_arch(self._arch) if self._arch is not None else False
def is_cdna_arch(self) -> bool:
"""
Checks if the current architecture is a CDNA architecture.
Returns:
bool: True if the architecture is CDNA, False otherwise.
"""
return is_cdna_arch(self._arch) if self._arch is not None else False
def equivalent_function(self) -> PrimFunc:
"""
Returns the function associated with this template.
Returns:
PrimFunc: The stored function.
"""
return self._func
def initialize_function(self) -> None:
"""
Placeholder method that should be implemented by subclasses.
This method is responsible for initializing the function.
Raises:
NotImplementedError: If not implemented in the subclass.
"""
raise NotImplementedError("initialize_function is not implemented")
def set_function(self, func: PrimFunc) -> "BaseTemplate":
"""
Sets the function for this template and returns itself.
Args:
func (PrimFunc): The function to associate with this template.
Returns:
BaseTemplate: The instance with the updated function.
"""
self._func = func
return self
def recommend_hints(self, topk: int = 10) -> List[Hint]:
"""
Provides a list of recommended hardware-aware configurations.
Args:
topk (int, optional): Number of top configurations to return. Defaults to 10.
Returns:
List[Hint]: A list of recommended configurations.
"""
return self.get_hardware_aware_configs(self._arch, topk)
@property
def arch(self) -> TileDevice:
"""
Returns the current architecture.
Returns:
TileDevice: The architecture of this template.
"""
return self._arch
def __post_init__(self):
"""
Post-initialization method that is called after the data class is created.
Ensures that the function is initialized.
"""
self.initialize_function()
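# Illustrative sketch (not part of the original file): the smallest shape a concrete
# template can take. The subclass name and the B = A + 1 workload are invented; the
# real templates (matmul, gemv, elementwise, ...) follow this same pattern.
@dataclass
class _ExampleAddOneTemplate(BaseTemplate):
    n: int = 1024
    def initialize_function(self) -> None:
        # Called automatically by BaseTemplate.__post_init__.
        from tvm import te
        A = te.placeholder((self.n,), name="A", dtype="float16")
        B = te.compute((self.n,), lambda i: A[i] + 1, name="B")
        self.set_function(te.create_prim_func([A, B]))
    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
        # Assumes the same utils helper the shipped templates use.
        from ..utils import get_roller_hints_from_func
        return get_roller_hints_from_func(self._func, arch=arch, topk=topk)
# Usage: _ExampleAddOneTemplate(n=4096).recommend_hints(topk=5)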
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules
from dataclasses import dataclass # Used for defining data classes
from .base import BaseTemplate # Importing the base class for templates
from tvm import te # Importing TVM's tensor expression module
from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations
from ..roller import Hint # Importing Hint for optimization hints
from typing import List # Importing List type hint
from ..utils import get_roller_hints_from_func # Function to obtain optimization hints
@dataclass
class ElementwiseTemplate(BaseTemplate):
"""
A template for element-wise operations using TVM.
Attributes:
shape (List[int]): The shape of the tensor.
dtype (str): The data type of the tensor (default: "float16").
"""
# OP Related Config
shape: List[int] = None # Shape of the tensor
dtype: str = "float16" # Data type of the tensor
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
"""
Retrieves hardware-aware optimization configurations.
Args:
arch (TileDevice, optional): The target hardware architecture.
topk (int, optional): Number of top configurations to consider.
Returns:
List[Hint]: A list of optimization hints for the given architecture.
"""
roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=True)
return roller_hints
def initialize_function(self) -> None:
"""
Initializes the element-wise computation function.
Defines a simple element-wise computation: B = A + 1, where A is an input tensor.
The computation graph is built using TVM's tensor expressions.
"""
shape, dtype = self.shape, self.dtype # Extract shape and dtype
# Define a placeholder tensor A
A = te.placeholder(shape, name="A", dtype=dtype)
# Define the element-wise computation (adding 1 to each element)
def _compute_elementwise(*indices):
return A[indices] + 1
# Define the computation for B based on A
B = te.compute(
shape,
fcompute=_compute_elementwise, # Function that defines element-wise computation
name="B",
)
# Store input and output tensors as function arguments
args = [A, B]
# Create and set the computation function
self.set_function(te.create_prim_func(args))
def params_as_dict(self):
"""
Returns the parameters of the template as a dictionary.
Returns:
dict: A dictionary containing shape and dtype.
"""
return {"shape": self.shape, "dtype": self.dtype}
@property
def class_attributes(self):
"""
Returns class attributes as a dictionary.
Returns:
dict: A dictionary representation of the class attributes.
"""
return self.params_as_dict()
def __repr__(self) -> str:
"""
Returns a string representation of the object.
Returns:
str: A string describing the instance with its parameters.
"""
cls_name = self.__class__.__name__
fields = self.class_attributes
field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
return f"{cls_name}({field_str})"
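# Usage sketch (not part of the original file); assumes a supported device is
# visible so the architecture default can be auto-inferred.
if __name__ == "__main__":
    template = ElementwiseTemplate(shape=[1024, 1024], dtype="float16")
    for hint in template.recommend_hints(topk=5):
        print(hint)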
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from typing import List
from ..utils import get_roller_hints_from_func
@dataclass
class GEMVTemplate(BaseTemplate):
"""
A template for Generalized Matrix-Vector Multiplication (GEMV).
This template defines the computation for a matrix-vector multiplication
with configurable parameters such as transposition, data types, and bias addition.
"""
# Operation-related configuration parameters
N: int = None # Number of columns in matrix B (output width)
K: int = None # Number of rows in matrix B (input width)
trans_B: bool = True # Whether to transpose matrix B
in_dtype: str = "float16" # Input data type
out_dtype: str = "float16" # Output data type
accum_dtype: str = "float16" # Accumulation data type
with_bias: bool = False # Whether to add a bias term
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
"""
Retrieves optimized hardware-aware configurations.
Args:
arch (TileDevice, optional): The target hardware architecture.
topk (int, optional): Number of top configurations to consider.
Returns:
List[Hint]: A list of optimization hints for hardware acceleration.
"""
roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk)
return roller_hints
def initialize_function(self) -> None:
"""
Defines and initializes the GEMV computation function.
This method sets up placeholders for input matrices, computes
the matrix-vector multiplication using TVM's compute API,
and optionally applies bias and type casting.
"""
        M: int = 1  # GEMV: the input is a single row vector, so M is fixed at 1
N, K = self.N, self.K
        # Ensure M, N, K are valid positive integers
        assert (isinstance(M, int) and isinstance(N, int) and
                isinstance(K, int)), "M, N, and K must be integers"
        assert (M > 0 and N > 0 and K > 0), "M, N, and K must be positive"
# Load configuration parameters
trans_B = self.trans_B
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
with_bias = self.with_bias
# Define tensor shapes
input_shape = (M, K) # Shape of input matrix A
weight_shape = (K, N) if not trans_B else (N, K) # Shape of weight matrix B
output_shape = (M, N) # Shape of output matrix C
Bias_shape = (N,) # Shape of bias vector
# Create TVM placeholders for input tensors
A = te.placeholder(input_shape, name="A", dtype=in_dtype) # Input matrix
B = te.placeholder(weight_shape, name="B", dtype=in_dtype) # Weight matrix
Bias = te.placeholder(Bias_shape, name="Bias", dtype=accum_dtype) # Bias vector
# Define a reduction axis for matrix multiplication
k = te.reduce_axis((0, K), name="k")
def _compute_matmul(i, j):
"""
Compute function for matrix-vector multiplication.
Args:
i (int): Row index.
j (int): Column index.
Returns:
Computed value for C[i, j] as a sum over the reduction axis.
"""
A_indices = [i, k]
B_indices = [k, j] if not trans_B else [j, k]
return te.sum(
A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype),
axis=k)
# Compute matrix multiplication result
C = te.compute(
output_shape,
fcompute=_compute_matmul,
name="C",
)
# Optionally apply bias addition
if with_bias:
C = te.compute(
output_shape,
lambda i, j: C[i, j] + Bias[j],
name="Bias",
)
# Optionally cast the output to a different type
if out_dtype != accum_dtype:
C = te.compute(
output_shape,
lambda i, j: C[i, j].astype(out_dtype),
name="D",
)
# Set function arguments (including bias if used)
args = [A, B, Bias, C] if self.with_bias else [A, B, C]
self.set_function(te.create_prim_func(args))
def params_as_dict(self):
"""
Returns the template parameters as a dictionary.
Returns:
dict: Dictionary containing template parameter values.
"""
return {
"N": self.N,
"K": self.K,
"trans_B": self.trans_B,
"in_dtype": self.in_dtype,
"out_dtype": self.out_dtype,
"accum_dtype": self.accum_dtype,
"with_bias": self.with_bias,
}
@property
def class_attributes(self):
"""
Returns the class attributes in dictionary form.
Returns:
dict: Dictionary of class attributes.
"""
return self.params_as_dict()
def __repr__(self) -> str:
"""
Returns a string representation of the class instance.
Returns:
str: A formatted string representation of the class.
"""
cls_name = self.__class__.__name__
fields = self.class_attributes
field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
return f"{cls_name}({field_str})"
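# Usage sketch (not part of the original file): a 1 x K row vector times a
# transposed N x K weight matrix; assumes a supported device for arch inference.
if __name__ == "__main__":
    template = GEMVTemplate(N=1024, K=1024, trans_B=True, in_dtype="float16")
    print(template.equivalent_function())
    for hint in template.recommend_hints(topk=5):
        print(hint)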
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from typing import List, Union
from ..utils import get_roller_hints_from_func
@dataclass
class GeneralReductionTemplate(BaseTemplate):
# OP Related Config
structure: Union[str, List[str]] = None
shape: List[int] = None
dtype: str = "float16"
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
roller_hints = get_roller_hints_from_func(
self._func, arch=arch, topk=topk, allow_gemv=False)
return roller_hints
def initialize_function(self) -> None:
"""
Parse the structure (e.g., 'SSR'), build the TVM compute definition
with the appropriate spatial and reduce axes, and store it in self._func.
"""
        if self.structure is None or self.shape is None:
            raise ValueError("Must provide both `structure` and `shape`.")
        assert isinstance(self.structure, str), "`structure` must currently be a string."
        if len(self.structure) != len(self.shape):
            raise ValueError("`structure` length must match `shape` length.")
if not all(isinstance(s, int) and s > 0 for s in self.shape):
raise ValueError("All dimensions in `shape` must be positive integers.")
# Separate axes into spatial vs reduce
spatial_axes = []
reduce_axes = []
for i, axis_type in enumerate(self.structure):
if axis_type.upper() == 'S':
spatial_axes.append((i, self.shape[i]))
elif axis_type.upper() == 'R':
reduce_axes.append((i, self.shape[i]))
else:
raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")
# Create input placeholder
A = te.placeholder(shape=self.shape, dtype=self.dtype, name="A")
# Build a list of te.reduce_axis (for R) and the final output shape (for S).
# We'll index them in order so that the compute lambda is consistent.
# Example for SSR => 2 spatial dims (i, j), 1 reduce dim (k).
# (1) Prepare the spatial dimensions:
# The output shape is the product of all spatial axes in the same order they appear.
# We'll construct a tuple for the final te.compute's shape. Example: (i, j).
spatial_extents = [ext for (_, ext) in spatial_axes]
# (2) Prepare reduce axes
# e.g. (k0, (0, extent)), (k1, (0, extent)), ...
reduce_axis_objs = []
for _, ext in reduce_axes:
reduce_axis_objs.append(te.reduce_axis((0, ext)))
# We need to build a function that uses the correct index mapping.
# Let's define a small helper that maps from the "spatial" indices to the
# correct A[] indexing, and includes the reduce axes as well.
# The final compute's shape is precisely the number of spatial axes in the same order.
out_shape = tuple(spatial_extents)
# We'll create a lambda of the form:
# (i, j, ...) -> te.sum(A[i, j, k, ...], axis=[k, ...])
# We can do this dynamically by constructing indexing for each dimension in `A`.
def compute_func(*spatial_indices):
# spatial_indices is a tuple of the same length as spatial_axes
# We must place each spatial index into the correct dimension of `A`
# or reduce_axis. Then for the reduce axes, we use the reduce_axis_objs in order.
# We want to build a full indexing that has length = len(self.shape).
# E.g. structure='SSR', shape=[S0, S1, R2]
# i, j -> A[i, j, k]
# where i = spatial_indices[0], j = spatial_indices[1]
full_index = []
spatial_iter = 0
reduce_iter = 0
# Walk through the structure in order
for axis_type in self.structure:
if axis_type.upper() == 'S':
# use the next spatial_indices item
full_index.append(spatial_indices[spatial_iter])
spatial_iter += 1
else:
# axis_type is 'R', use the next reduce_axis_obj
full_index.append(reduce_axis_objs[reduce_iter])
reduce_iter += 1
# Now we do the sum:
return te.sum(A[tuple(full_index)], axis=tuple(reduce_axis_objs))
# Construct the output tensor with te.compute
C = te.compute(out_shape, compute_func, name="C")
# Create a PrimFunc from placeholders + output
args = [A, C]
prim_func = te.create_prim_func(args)
self.set_function(prim_func)
def params_as_dict(self):
return {"shape": self.shape, "dtype": self.dtype}
@property
def class_attributes(self):
return self.params_as_dict()
def __repr__(self) -> str:
cls_name = self.__class__.__name__
fields = self.class_attributes
field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
return f"{cls_name}({field_str})"
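# Usage sketch (not part of the original file): structure "SSR" keeps the first two
# axes spatial and reduces the last one, i.e. C[i, j] = sum_k A[i, j, k]; assumes a
# supported device for arch inference.
if __name__ == "__main__":
    template = GeneralReductionTemplate(structure="SSR", shape=[64, 1024, 1024], dtype="float16")
    print(template.equivalent_function())
    for hint in template.recommend_hints(topk=5):
        print(hint)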