Commit 1ef644e7 authored by Lei Wang, committed by GitHub

[CostModel][Carver] Support Hint Recommendation for Shared-Memory Kernel Fusion (#73)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University

* enhance acknowledgement

* Replace tutorial on memory layout optimization with new tutorial on writing high-performance kernels with thread primitives

* Update subproject commit for TVM dependency

* Update subproject commit for TVM dependency

* Add int4_t type and functions for packing char values in CUDA common header

* Add plot_layout example and implement GetForwardVars method in layout classes

* Refactor code for improved readability by adjusting line breaks and formatting in layout and test files

* Fix formatting by removing unnecessary line break in layout.h

* Refactor make_int4 function for improved readability by adjusting parameter formatting

* Add legend to plot_layout for improved clarity of thread and local IDs

* Remove unnecessary dependencies from requirements files for cleaner setup

* Remove flash_mha.py and add .gitkeep to deepseek_mla directory

* Add build requirements and update installation scripts for improved setup

* Introduce carver

* Refactor imports and improve code formatting for consistency

* Add unit tests for carver recommendation hints

* lint fix

* Enhance ElementwiseTemplate and BaseTemplate with detailed docstrings for improved code documentation and clarity

* Refactor import statements and clean up whitespace in template files for improved readability

* Add README.md for Carver framework with usage examples and architecture support

* Refactor import statement in matmul_analysis.py for consistency

* Refactor TileDict and TensorCorePolicy methods for improved clarity and functionality

* Add tests for general matrix multiplication emit configurations

* Refactor formatting in test_tilelang_carver_generate_hints.py for improved readability

* Add FlashAttentionTemplate and related functionality for hint recommendations

* Refactor whitespace in FlashAttentionTemplate and test_tilelang_carver_recommend_hints for improved readability
parent 465f0107
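The headline change is that carver can now recommend tiling hints for graphs of fused kernels that share intermediates in shared memory, not just for a single PrimFunc. A minimal sketch of the template-level entry point, mirroring the tests added in this commit (the exact hint fields depend on the detected architecture):

from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch

# Describe a flash-attention style fusion (Q @ K^T followed by scores @ V) and
# ask carver for the top-k hardware-aware tiling hints.
template = carver.FlashAttentionTemplate(
    batch_size=4,
    num_heads=32,
    seq_length=512,
    seq_kv_length=512,
    head_dim=128,
    in_dtype="float16",
    accum_dtype="float16",
    out_dtype="float16",
).with_arch(auto_infer_current_arch())

for hint in template.recommend_hints(topk=10):
    print(hint)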
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.testing
from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch
from tvm import te


def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
        A = te.placeholder((M, K), name='A', dtype='float16')
        B = te.placeholder((N, K), name='B', dtype='float16')
        # Describe the matrix multiplication in TE
        k = te.reduce_axis((0, K), name='k')
        C = te.compute(
            (M, N),
            lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
            name='C')
        return A, B, C

    args = gemm(M, N, K)
    func = te.create_prim_func(args)
    tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
    print(tags)

    policy = carver.TensorCorePolicy.from_prim_func(
        func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
    hints = policy.emit_config(topk=topk)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length is zero"

    prim_func_node = PrimFuncNode(tensorized_func, name="matmul_1")
    output_nodes = [OutputNode(prim_func_node)]
    policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags)
    hints = policy.emit_config(topk=10)
    for config in hints:
        print(config)
    assert len(hints) > 0, "Hints length is zero"


def test_general_matmul_emit_configs():
    run_general_matmul_emit_configs(128, 128, 128)


def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
        A = te.placeholder((M, K), name='A', dtype='float16')
        B = te.placeholder((N, K), name='B', dtype='float16')
        # Describe the matrix multiplication in TE
        k = te.reduce_axis((0, K), name='k')
        C = te.compute(
            (M, N),
            lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
            name='C')
        return A, B, C

    args = gemm(M, N, K)
    func = te.create_prim_func(args)
    tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
    print(tags)

    node_0 = PrimFuncNode(tensorized_func, name="matmul_0")
    node_1 = PrimFuncNode(tensorized_func, name="matmul_1")
    edge = Edge(node_0, node_1, 0, 0)
    node_0._out_edges.append(edge)
    node_1.set_inputs(0, edge)
    output_nodes = [OutputNode(node_1)]

    policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags)
    hints = policy.emit_config(topk=topk)
    for config in hints:
        print(config)
    assert len(hints) > 0, "Hints length is zero"


def test_general_matmul_matmul_emit_configs():
    run_general_matmul_matmul_emit_configs(128, 128, 128)


if __name__ == "__main__":
    tilelang.testing.main()
......@@ -110,5 +110,41 @@ def test_gemv_recommend_hints():
    run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16")


def run_fmha_recommend_hints(
    batch_size: int = 4,
    num_heads: int = 32,
    seq_length: int = 512,
    seq_kv_length: int = 512,
    head_dim: int = 128,
    in_dtype: str = "float16",
    accum_dtype: str = "float16",
    out_dtype: str = "float16",
):
    arch = auto_infer_current_arch()
    carve_template = carver.FlashAttentionTemplate(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_length=seq_length,
        seq_kv_length=seq_kv_length,
        head_dim=head_dim,
        in_dtype=in_dtype,
        accum_dtype=accum_dtype,
        out_dtype=out_dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=20)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length should be greater than 0"


def test_fmha_recommend_hints():
    run_fmha_recommend_hints(4, 32, 512, 512, 128, "float16", "float16", "float16")
    run_fmha_recommend_hints(4, 32, 512, 512, 128, "int8", "int32", "int32")


if __name__ == "__main__":
    tilelang.testing.main()
......@@ -13,4 +13,4 @@ from .analysis import (
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode # noqa: F401
from .node import PrimFuncNode, OutputNode, Edge # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401
......
......@@ -33,7 +33,8 @@ class BestFit:
size = (size + self.align - 1) // self.align * self.align
found = None
for block in self.list:
if block.is_free and block.size() >= size and not found or found.size() > block.size():
if block.is_free and block.size() >= size and (not found or
found.size() > block.size()):
found = block
if found:
found.is_free = False
......
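For context, the BestFit hunk above is an operator-precedence fix: in Python, `and` binds tighter than `or`, so the old condition parsed as `(is_free and fits and not found) or found.size() > block.size()` and could call `.size()` on `found` while it was still None. A minimal standalone sketch of the difference, using hypothetical boolean placeholders rather than the allocator itself:

found = None
is_free, fits = False, False

# Old parse: (is_free and fits and not found) or found.size() > 0
#   -> when the left side is False, Python evaluates found.size(), which
#      raises AttributeError while found is still None.
# New parse: is_free and fits and (not found or found.size() > 0)
#   -> short-circuits on the first False term and never touches found.
print(is_free and fits and (not found or found.size() > 0))  # prints False, no error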
......@@ -99,8 +99,8 @@ class TileDict:
def get_tile(self, func) -> List[int]:
return self.tile_map[func]
def get_rstep(self, func) -> Dict[str, int]:
return self.rstep_map
def get_rstep(self, node) -> Dict[str, int]:
return self.rstep_map[node]
def __hash__(self) -> int:
return hash(tuple(self.output_tile))
......
......@@ -13,6 +13,7 @@ from ..analysis import BlockInfo, get_reduction_blocks
from .. import analysis
from .. import normalize_prim_func
from .shape_inference import get_analyzer_by_tir
from dataclasses import dataclass
def pre_order_traverse(block_analyzer, blocks, func):
......@@ -83,13 +84,28 @@ class BlockAnalyzer(object):
return self.sch.get_consumers(block)
@dataclass
class Edge:
src_node: 'Node'
dst_node: 'Node'
src_id: int
dst_id: int
class Node(object):
def __init__(self, tags: Optional[Dict] = None) -> None:
def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
self.name = name
if tags is None:
tags = {}
self._out_edges = []
self._in_edges = []
self._shapes = []
self._dtypes = []
self._tag: Dict = {}
self.update_tags(tags)
def update_tags(self, tags: Dict) -> None:
for tag in tags:
self.add_tag(tag, tags[tag])
......@@ -104,11 +120,82 @@ class Node(object):
return None
return self._tag[k]
def is_placeholder(self):
return False
def is_output(self):
return False
@property
def inputs(self) -> List[Edge]:
return self._in_edges
@property
def outputs(self) -> List[Edge]:
return self._out_edges
def set_inputs(self, i: int, edge: Edge):
assert i < len(self._in_edges)
self._in_edges[i] = edge
def set_outputs(self, i: int, edge: Edge):
assert i < len(self._out_edges)
self._out_edges[i] = edge
def get_dtype(self, id=0) -> tvm.DataType:
return self._dtypes[id]
def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
assert isinstance(dtype, tvm.DataType), type(dtype)
if dtype == tvm.DataType("bool"):
dtype = tvm.DataType("int8")
if len(self._dtypes) <= id:
self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)])
elif self._dtypes[id] is not None:
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
def get_shape(self, id: int = 0) -> List[int]:
return self._shapes[id]
def set_shape(self, shape: List[int], id=0, overwrite=False) -> None:
if len(self._shapes) <= id:
self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)])
# elif self._shapes[id] is not None and not overwrite:
# assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape)))
self._shapes[id] = list(map(int, shape))
def num_outputs(self) -> int:
if len(self.outputs) == 0:
return 0
return max([e.src_id for e in self.outputs]) + 1
def get_ir(self) -> str:
raise NotImplementedError()
def __repr__(self) -> str:
return "<Node, " + self.name + ">"
class PlaceHolderNode(Node):
def __init__(self, name=""):
super().__init__(name="PlaceHolder_" + name)
def is_placeholder(self):
return True
def get_ir(self) -> str:
return "placeholder"
class PrimFuncNode(Node):
def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None:
super().__init__(tags)
def __init__(self,
prim_func: PrimFunc,
tags: Optional[Dict] = None,
name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
......@@ -122,8 +209,31 @@ class PrimFuncNode(Node):
self.buffers = []
self.args = []
self._analysis_funcinfo()
self._assign_placeholder_node()
self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks)
# set input shapes and dtypes
for edge, buffer in zip(self.inputs, self.input_buffers):
edge.src_node.set_shape(buffer.shape, edge.src_id)
edge.src_node.set_dtype(tvm.DataType(buffer.dtype), edge.src_id)
for output_id, buffer in enumerate(self.output_buffers):
self.set_shape(buffer.shape, output_id)
self.set_dtype(tvm.DataType(buffer.dtype), output_id)
def _assign_placeholder_node(self):
inputs: List[Node] = []
for buffer in self.input_buffers:
inputs.append(PlaceHolderNode(buffer.name))
for dst_id, n in enumerate(inputs):
if isinstance(n, Node):
n = (n, 0)
assert (len(n) == 2)
src_node, src_id = n[0], n[1]
edge = Edge(src_node, self, src_id, dst_id)
self._in_edges.append(edge)
src_node._out_edges.append(edge)
def _specialize_func(self, func: PrimFunc):
# Specialize the function to make it more friendly for analysis.
# set attrs
......@@ -222,9 +332,6 @@ class PrimFuncNode(Node):
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
def get_dtype(self, id=0) -> tvm.DataType:
return self._dtypes[id]
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)
......@@ -407,3 +514,97 @@ class PrimFuncNode(Node):
    def get_input_buffers(self) -> List[tir.Buffer]:
        return self.block_analyzer.input_buffers


class OutputNode(Node):

    def __init__(self, node, id=0):
        super().__init__(name="OutputNode")
        # connect node and output node
        assert isinstance(node, PrimFuncNode), "OutputNode should connect to PrimFuncNode"
        # initialize edge and connect
        src_node, src_id = node, id
        edge = Edge(src_node, self, src_id, 0)
        self._in_edges.append(edge)
        src_node._out_edges.append(edge)
        self.set_shape(node.get_shape(id))
        self.set_dtype(node.get_dtype(id))

    def is_output(self):
        return True

    def get_ir(self) -> str:
        return "output"


def topo_order(list_of_nodes) -> List[Node]:
    input_ready_count = {node: len(node.inputs) for node in list_of_nodes}
    ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes))
    output_list = []
    while len(ready) > 0:
        node = ready.pop(0)
        output_list.append(node)
        for edge in node.outputs:
            dst_node = edge.dst_node
            if dst_node not in input_ready_count:
                input_ready_count[dst_node] = len(dst_node.inputs)
                list_of_nodes.append(dst_node)
            input_ready_count[dst_node] -= 1
            assert (input_ready_count[dst_node] >= 0)
            if input_ready_count[dst_node] == 0:
                ready.append(dst_node)
    assert (len(list_of_nodes) == len(output_list))
    return output_list


def find_topo_sort_priority(output_node_list) -> List[Node]:
    import sys
    sys.setrecursionlimit(10000)

    def topo_sort_get_layer(node, topo_layer):
        if node in topo_layer:
            return
        topo_layer[node] = 0
        for edge in node.inputs:
            topo_sort_get_layer(edge.src_node, topo_layer)
            topo_layer[node] = max(topo_layer[node], topo_layer[edge.src_node] + 1)

    topo_layer = {}
    for node in output_node_list:
        topo_sort_get_layer(node, topo_layer)

    def topo_sort_dfs(node, visited, topo_order):
        if node in visited:
            return
        visited.add(node)
        ordered_input_nodes = sorted([edge.src_node for edge in node.inputs],
                                     key=lambda n: topo_layer[n],
                                     reverse=True)
        for n in ordered_input_nodes:
            topo_sort_dfs(n, visited, topo_order)
        topo_order.append(node)

    visited = set()
    topo_order = []
    for node in output_node_list:
        topo_sort_dfs(node, visited, topo_order)
    return topo_order


def find_topo_sort(output_node_list) -> List[Node]:

    def topo_sort_dfs(node, visited, topo_order):
        if node in visited:
            return
        visited.add(node)
        for edge in node.inputs:
            topo_sort_dfs(edge.src_node, visited, topo_order)
        topo_order.append(node)

    visited = set()
    topo_order = []
    for node in output_node_list:
        topo_sort_dfs(node, visited, topo_order)
    return topo_order
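The Edge/OutputNode machinery above is what lets a policy reason about a chain of PrimFuncNodes instead of a single function. A minimal sketch of wiring two GEMM stages together and asking TensorCorePolicy for fused hints, mirroring the new test (for brevity both stages reuse the same tensorized function):

from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch
from tvm import te

arch = auto_infer_current_arch()

# A 128x128x128 float16 GEMM expressed in TE, then tensorized for the target.
A = te.placeholder((128, 128), name="A", dtype="float16")
B = te.placeholder((128, 128), name="B", dtype="float16")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128),
               lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]),
               name="C")
func = te.create_prim_func([A, B, C])
tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)

# Stage 0's output (src_id=0) feeds input slot 0 (dst_id=0) of stage 1.
node_0 = PrimFuncNode(tensorized_func, name="matmul_0")
node_1 = PrimFuncNode(tensorized_func, name="matmul_1")
edge = Edge(node_0, node_1, 0, 0)
node_0._out_edges.append(edge)
node_1.set_inputs(0, edge)

policy = carver.TensorCorePolicy.from_output_nodes([OutputNode(node_1)], arch=arch, tags=tags)
for hint in policy.emit_config(topk=10):
    print(hint)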
......@@ -13,7 +13,7 @@ from ...arch import TileDevice
from ..bestfit import BestFit
from ..hint import Hint, Stride, TileDict
from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors
from ..node import PrimFuncNode
from ..node import PrimFuncNode, OutputNode, find_topo_sort
from ..rasterization import NoRasterization
......@@ -23,23 +23,69 @@ class DefaultPolicy:
minimize memory traffic and maximize parallelism for BitBLAS schedules.
"""
def __init__(self,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None) -> None:
func: tvm.tir.PrimFunc
nodes: List[PrimFuncNode] = []
arch: TileDevice
tags: Dict
def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
if tags is None:
tags = {}
self.arch = arch
self.prim_func_node = PrimFuncNode(func, tags)
self.ordered_nodes = [self.prim_func_node]
self.output_nodes = [self.prim_func_node]
self.tags = tags
self.rasterization = NoRasterization()
@classmethod
def from_prim_func(cls,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None,
name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod
def from_output_nodes(cls,
nodes: List[OutputNode],
arch: TileDevice,
tags: Optional[Dict] = None):
return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self,
func: tvm.tir.PrimFunc,
name: str = "PrimFuncNode") -> "DefaultPolicy":
if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
else:
raise NotImplementedError("Only support PrimFunc for now")
output_nodes = [OutputNode(self.prim_func_node)]
self._init_with_output_nodes(output_nodes)
return self
def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
self.ordered_nodes = list(
filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes)))
for node in self.ordered_nodes:
node.update_tags(self.tags)
self.output_nodes = []
for node in self.ordered_nodes:
is_topo_output = True
for edge in node.outputs:
if not edge.dst_node.is_output():
is_topo_output = False
if is_topo_output:
self.output_nodes.append(node)
return self
def emit_config(self, topk: int) -> List[Hint]:
base_tile = self.get_base_tile()
if base_tile is None:
return []
rstep_map = self._assign_reduce_step(self.prim_func_node)
rstep_map = {node: self._assign_reduce_step(node) for node in self.ordered_nodes}
smem_tile_candidates = self.dfs_smem_tile(base_tile, rstep_map)
results = []
for td in smem_tile_candidates:
......@@ -56,7 +102,7 @@ class DefaultPolicy:
return results
def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
_steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()]
_steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()]
steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)]
for i in range(len(steps)):
added = list(
......@@ -104,8 +150,26 @@ class DefaultPolicy:
The base tile configuration, which is a list of 1s equal in length to the space dimensions
of the output node.
"""
shape = self.prim_func_node.get_space_dim()
if len(set([len(node.get_space_dim()) for node in self.output_nodes])) > 1:
# If output dim sizes are not same, don't know how to handle them
return None
out_node = self.output_nodes[0]
shape = out_node.get_space_dim()
base_tile = [1 for _ in shape]
wpi = self.compute_workload_per_item(base_tile)
for dim, n in enumerate(shape):
factors = [n]
for factor in factors:
if factor == base_tile[dim]:
continue
tile = base_tile.copy()
tile[dim] = factor
new_wpi = self.compute_workload_per_item(tile)
if new_wpi < wpi:
wpi, base_tile = new_wpi, tile
else:
break
return base_tile
......@@ -126,12 +190,25 @@ class DefaultPolicy:
based on the output nodes' space dimensions.
"""
tile_map = {}
tile_map[self.prim_func_node] = [
tile[i] * self.prim_func_node.get_space_dim()[i] //
self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))
for node in self.output_nodes:
tile_map[node] = [
tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i]
for i in range(len(tile))
]
return tile_map
def compute_workload_per_item(self, output_tile) -> float:
op_tile_map = self._get_output_tile_map(output_tile)
compute = 0
num_item = int(np.prod(output_tile))
for node in reversed(self.ordered_nodes):
tile = op_tile_map[node]
dep = node.propagate_inputs(tile)
compute += int(np.prod(tile))
for i, edge in enumerate(node.inputs):
op_tile_map[edge.src_node] = dep[i]
return float(compute / num_item)
def score_block_size(self, n):
"""
Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition.
......@@ -312,7 +389,7 @@ class DefaultPolicy:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
new_rstep_map[node] = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
}
old_rstep_map = td.rstep_map
......@@ -328,8 +405,8 @@ class DefaultPolicy:
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep
rstep = _optimize(node, rstep_map[node])
rstep_map[node] = rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
......@@ -353,18 +430,21 @@ class DefaultPolicy:
tile = op_tile_map[node]
input_shapes = node.propagate_inputs(tile)
output_shapes = node.propagate_outputs(tile)
for i, buffer in enumerate(node.input_buffers):
nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8
for i, edge in enumerate(node.inputs):
op_tile_map[edge.src_node] = input_shapes[i]
if edge.src_node.is_placeholder():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8
read_transaction_elements = self.arch.transaction_size[1] // nbytes
traffic += (
coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements)
* nbytes)
for i, buffer in enumerate(node.output_buffers):
nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8
traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(),
read_transaction_elements) * nbytes
for edge in node.outputs:
if edge.dst_node.is_output():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8
write_transaction_elements = self.arch.transaction_size[0] // nbytes
traffic += (
coalesced_tensor_shape(output_shapes[i], buffer.shape,
write_transaction_elements) * nbytes)
traffic += coalesced_tensor_shape(output_shapes[edge.src_id],
node.get_shape(edge.src_id),
write_transaction_elements) * nbytes
return traffic, op_tile_map
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
......@@ -404,12 +484,32 @@ class DefaultPolicy:
self._compute_stride_map(td)
allocator = BestFit()
block_map = {}
processed = set()
cached_tensors_map = {}
node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage(
td, self.prim_func_node)
def can_free(node, out_id):
for edge in node.outputs:
if edge.src_id == out_id and edge.dst_node not in processed:
return False
return True
for node in self.ordered_nodes:
node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node)
block = allocator.malloc(node_internal_bytes)
allocator.free(block)
# free inputs
processed.add(node)
for edge in node.inputs:
if not edge.src_node.is_placeholder() and can_free(edge.src_node, edge.src_id):
allocator.free(block_map.pop((edge.src_node, edge.src_id)))
# alloc outputs
for edge in node.outputs:
if not edge.dst_node.is_output() and (node, edge.src_id) not in block_map:
dtype_bytes = (node.get_dtype(edge.src_id).bits + 7) // 8
stride = td.output_strides_map[node][len(node.inputs) + edge.src_id]
output_elem = stride.compute_elements_from_shape(td.get_tile(node))
block_map[(node, edge.src_id)] = allocator.malloc(output_elem * dtype_bytes)
assert len(block_map) == 0
return allocator.limit, cached_tensors_map
......@@ -585,8 +685,9 @@ class DefaultPolicy:
for block_size in block_size_ordered:
result = {}
failed = False
result = self._assign_block_size(self.prim_func_node, td, block_size)
if result is None:
for node in self.ordered_nodes:
result[node] = self._assign_block_size(node, td, block_size)
if result[node] is None:
failed = True
break
if failed:
......@@ -678,7 +779,7 @@ class DefaultPolicy:
# Plan vectorize
codegen_dict.vectorize = self._plan_vectorize(node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.opt_shapes = node.get_tag("opt_shapes")
return codegen_dict
def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int):
......
......@@ -5,7 +5,6 @@ import tvm
from typing import Dict, List, Tuple, Optional
import numpy as np
import logging
from ...arch import TileDevice
from ..hint import Hint, Stride, TileDict, IntrinInfo
from ..node import PrimFuncNode
from .common import coalesced_factor, factorize, get_all_factors
......@@ -17,18 +16,17 @@ logger = logging.getLogger(__name__)
class TensorCorePolicy(DefaultPolicy):
def __init__(self,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None) -> None:
super().__init__(func, arch, tags)
# this is the trick for wmma.
# However, for int8 mma, the wmma_k should be 32.
self.wmma_k = 16
self.pipeline_stage: int = 1
self.use_async_copy: bool = False
self.block_reduction_depth: Optional[int] = None
wmma_k: int = 16
pipeline_stage: int = 1
use_async_copy: bool = False
block_reduction_depth: Optional[int] = None
def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: Optional[str] = None):
super()._init_with_prim_func(func, name)
self._legalize_info()
return self
def _legalize_info(self):
pipleline_stage = self.prim_func_node.get_tag("pipeline_stage")
......@@ -184,8 +182,8 @@ class TensorCorePolicy(DefaultPolicy):
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep
rstep = _optimize(node, rstep_map[node])
rstep_map[node] = rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
......@@ -335,9 +333,9 @@ class TensorCorePolicy(DefaultPolicy):
codegen_dict.shared_scope = "shared.dyn"
codegen_dict.complete_config(node)
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.vectorize = self._plan_vectorize(node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.opt_shapes = node.get_tag("opt_shapes")
codegen_dict.tensorcore_legalization()
return codegen_dict
......
......@@ -7,3 +7,4 @@ from .matmul import MatmulTemplate # noqa: F401
from .gemv import GEMVTemplate # noqa: F401
from .elementwise import ElementwiseTemplate # noqa: F401
from .general_reduce import GeneralReductionTemplate # noqa: F401
from .flashattention import FlashAttentionTemplate # noqa: F401
......@@ -6,7 +6,8 @@ from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
from ..roller import Hint # Import the Hint class
from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class
from typing import List # For type hinting
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
......@@ -25,6 +26,9 @@ class BaseTemplate(ABC):
# The function associated with this template, initially None
_func: PrimFunc = field(default=None, init=False, repr=False)
# The output nodes associated with this template, initially None
_output_nodes: List[OutputNode] = field(default=None, init=False, repr=False)
@abstractmethod
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
"""
......@@ -122,6 +126,19 @@ class BaseTemplate(ABC):
self._func = func
return self
def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate":
"""
Sets the output nodes for this template and returns itself.
Args:
output_nodes (List[OutputNode]): The output nodes to associate with this template.
Returns:
BaseTemplate: The instance with the updated output nodes.
"""
self._output_nodes = output_nodes
return self
def recommend_hints(self, topk: int = 10) -> List[Hint]:
"""
Provides a list of recommended hardware-aware configurations.
......@@ -144,6 +161,16 @@ class BaseTemplate(ABC):
"""
return self._arch
@property
def output_nodes(self) -> List[OutputNode]:
"""
Returns the output nodes associated with this template.
Returns:
List[OutputNode]: The output nodes.
"""
return self._output_nodes
def __post_init__(self):
"""
Post-initialization method that is called after the data class is created.
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from ..roller import PrimFuncNode, OutputNode, Edge
from typing import List
from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags


@dataclass
class FlashAttentionTemplate(BaseTemplate):

    _output_nodes: List[OutputNode] = None

    # Operation-related configuration parameters
    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
        """
        Retrieves optimized hardware-aware configurations.

        Args:
            arch (TileDevice, optional): The target hardware architecture.
            topk (int, optional): Number of top configurations to consider.

        Returns:
            List[Hint]: A list of optimization hints for hardware acceleration.
        """
        roller_hints = get_roller_hints_from_output_nodes(self.output_nodes, arch=arch, topk=topk)
        return roller_hints
    def initialize_function(self) -> None:
        """
        Defines and initializes the flash-attention computation as two chained
        matrix multiplications: scores = Q @ K^T followed by output = scores @ V.

        This method sets up placeholders for the input tensors, builds both
        matmuls with TVM's compute API, and wires the resulting PrimFuncNodes
        together so the policy can treat them as a fused graph.
        """
        batch_size = self.batch_size
        num_heads = self.num_heads
        head_dim = self.head_dim
        seq_length = self.seq_length
        seq_kv_length = self.seq_kv_length
        in_dtype = self.in_dtype
        out_dtype = self.out_dtype
        accum_dtype = self.accum_dtype

        # Flatten the attention shapes into two batched matmul shapes
        QK_B, QK_M, QK_N, QK_K = batch_size * num_heads, seq_length, seq_kv_length, head_dim
        SV_B, SV_M, SV_N, SV_K = batch_size * num_heads, seq_length, head_dim, seq_kv_length

        def create_matmul(B, M, N, K):
            # Define tensor shapes for a batched GEMM
            input_shape = (B, M, K)
            weight_shape = (B, N, K)
            output_shape = (B, M, N)  # Shape of output matrix C

            # Create TVM placeholders for input tensors
            A = te.placeholder(input_shape, name="A", dtype=in_dtype)  # Input matrix A
            B = te.placeholder(weight_shape, name="B", dtype=in_dtype)  # Weight matrix B

            # Define a reduction axis for matrix multiplication
            k = te.reduce_axis((0, K), name="k")

            def _compute_matmul(b, i, j):
                """
                Compute function for matrix multiplication.

                Args:
                    b (int): Batch index.
                    i (int): Row index.
                    j (int): Column index.

                Returns:
                    Computed value for C[b, i, j] as a sum over the reduction axis.
                """
                A_indices = [b, i, k]
                B_indices = [b, j, k]
                return te.sum(
                    A[tuple(A_indices)].astype(accum_dtype) *
                    B[tuple(B_indices)].astype(accum_dtype),
                    axis=k)

            # Compute matrix multiplication result
            C = te.compute(
                output_shape,
                fcompute=_compute_matmul,
                name="C",
            )

            # Optionally cast the output to a different type
            if out_dtype != accum_dtype:
                C = te.compute(
                    output_shape,
                    lambda b, i, j: C[b, i, j].astype(out_dtype),
                    name="D",
                )

            args = [A, B, C]
            return te.create_prim_func(args)

        MMA0_prim_func = create_matmul(QK_B, QK_M, QK_N, QK_K)
        MMA1_prim_func = create_matmul(SV_B, SV_M, SV_N, SV_K)

        self.set_function([MMA0_prim_func, MMA1_prim_func])

        def create_node_from_function(func, name):
            tensorized_func, tags = get_tensorized_func_and_tags(func, self.arch.target)
            assert tags is not None
            return PrimFuncNode(tensorized_func, name=name, tags=tags)

        node_0 = create_node_from_function(MMA0_prim_func, name="MMA0")
        node_1 = create_node_from_function(MMA1_prim_func, name="MMA1")

        # connect the two nodes
        edge = Edge(node_0, node_1, 0, 0)
        node_0._out_edges.append(edge)
        node_1.set_inputs(0, edge)
        output_nodes = [OutputNode(node_1)]
        self.set_output_nodes(output_nodes)
    def params_as_dict(self):
        """
        Returns the template parameters as a dictionary.

        Returns:
            dict: Dictionary containing template parameter values.
        """
        return {
            "batch_size": self.batch_size,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "seq_length": self.seq_length,
            "seq_kv_length": self.seq_kv_length,
            "is_causal": self.is_causal,
            "in_dtype": self.in_dtype,
            "out_dtype": self.out_dtype,
            "accum_dtype": self.accum_dtype,
        }

    @property
    def class_attributes(self):
        """
        Returns the class attributes in dictionary form.

        Returns:
            dict: Dictionary of class attributes.
        """
        return self.params_as_dict()

    def __repr__(self) -> str:
        """
        Returns a string representation of the class instance.

        Returns:
            str: A formatted string representation of the class.
        """
        cls_name = self.__class__.__name__
        fields = self.class_attributes
        field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
        return f"{cls_name}({field_str})"
......@@ -7,6 +7,7 @@ from tvm.tir import PrimFunc
from .arch import TileDevice
from .roller.policy import TensorCorePolicy, DefaultPolicy
from .roller.hint import Hint
from .roller.node import OutputNode
from .matmul_analysis import get_tensorized_func_and_tags
import logging
......@@ -56,7 +57,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
else:
return None
else:
policy = DefaultPolicy(func=func, arch=arch)
policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None
try:
tensorized_func, tags = get_tensorized_func_and_tags(
......@@ -65,10 +66,31 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
logger.debug("Get tensorized func and tags failed: ", e_msg)
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk)
def get_roller_hints_from_output_nodes(
        output_nodes: List[OutputNode],
        arch: TileDevice,
        topk: int = 10,
        extra_tags: Optional[List[str]] = None) -> Optional[List[Hint]]:
    assert isinstance(output_nodes, list), "The input should be a list of output nodes."
    hints = []
    try:
        policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
        hints = policy.emit_config(topk)
    except Exception as e_msg:
        logger.debug(f"Generate hints from output nodes failed: {e_msg}, falling back to the default policy")
    if len(hints) == 0:
        policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
        hints = policy.emit_config(topk)
    return hints


def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
    if not isinstance(ir_module, IRModule):
        raise ValueError("Not supported type: ", type(ir_module))
......