Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
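
The hunks below are mechanical reformatting only: single-quoted strings become double-quoted, yapf-style hanging indents are collapsed onto one line where they fit, redundant parentheses around assert and if conditions are dropped, and a few hand-rolled loops are folded into any()/all(). As a minimal before/after sketch of the style shift (a hypothetical helper, not code from this PR):

# Hypothetical example, not taken from this commit.

# Old, yapf-style formatting: hanging indent and single quotes.
def make_tags(kind,
              extra=None):
    return {
        'kind': kind,
        'extra': extra,
    }

# New, ruff-format style: collapsed onto one line where it fits, double quotes.
def make_tags(kind, extra=None):
    return {"kind": kind, "extra": extra}
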
@@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool:
 class METAL(TileDevice):
     def __init__(self, target: Target | str):
         if isinstance(target, str):
             target = Target(target)
@@ -16,6 +15,6 @@ class METAL(TileDevice):
 __all__ = [
-    'is_metal_arch',
-    'METAL',
+    "is_metal_arch",
+    "METAL",
 ]
@@ -19,6 +19,7 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm common_schedules.py in dlight.
 """Common schedule strategies for TIR."""
 from typing import Callable
 from tvm import tir
......
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
 from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
@@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block
     return block
-def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV,
-                                   buffer: tir.Buffer) -> int:
+def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int:
     """traverse to find the arg index from the buffer"""
     producers = sch.get_producers(main_block)
@@ -226,9 +226,7 @@ def make_iter_fusion_index_map(
         else:
             fused_iters[trait.kind] = v_i
-    final_indices: list[tir.PrimExpr] = [
-        fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
-    ]
+    final_indices: list[tir.PrimExpr] = [fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order]
     return tir.IndexMap(input_iters, final_indices, None)
@@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
     return A_traits, B_traits, C_traits, block_traits
-def get_index_map(block: tir.Block,
-                  layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
+def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
     """Get index maps for the block
     Parameters
@@ -343,10 +340,7 @@ def get_index_map(block: tir.Block,
         return axes
     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars)
     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)
@@ -384,17 +378,17 @@ def get_index_map(block: tir.Block,
         if kind == "C":
             return [IterKind.kIter_S, primary_iter, secondary_iter]
         else:
-            return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region)
-                    else [IterKind.kIter_S, reduction_iter, spatial_iter])
+            return (
+                [IterKind.kIter_S, spatial_iter, reduction_iter]
+                if check_last_trait(region)
+                else [IterKind.kIter_S, reduction_iter, spatial_iter]
+            )
     else:
         raise ValueError(f"Unknown layout {layout}")
-    A_index_map = make_iter_fusion_index_map(
-        A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
-    B_index_map = make_iter_fusion_index_map(
-        B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
-    C_index_map = make_iter_fusion_index_map(
-        C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
+    A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
+    B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
+    C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
     matmul_index_map = make_iter_fusion_index_map(
         block_traits,
@@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None:
         has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
         if not has_uint_input:
             return False
-        return not (len(block_stmt.writes) != 1 or
-                    "float" not in str(block_stmt.writes[0].buffer.dtype))
+        return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype))
     dequantize_blocks = [block for block in blocks if is_dequantize(block)]
     return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
@@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
             return None
         axes.extend(undefined_vars(r.min))
     # remove trivial axis
-    trivial_vars = set(
-        iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
+    trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
     axes = [axis for axis in axes if axis not in trivial_vars]
     # remove duplicate axis
     axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]]
@@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
     lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:]
     rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:]
     is_identity = list(lhs_access_vars) == list(rhs_access_vars)
-    is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(
-        rhs_access_vars)
+    is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars)
     return is_identity, is_transpose
@@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]
     return result_blocks
-def normalize_to_matmul(sch: tir.Schedule,
-                        main_block: BlockRV,
-                        layout: list[str] | None = None) -> tir.Schedule | None:
+def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None:
     if layout is None:
         layout = ["n", "t", "n"]
     block_stmt = sch.get(main_block)
@@ -526,7 +515,7 @@ def get_tensorized_func_and_tags(
     allow_gemv: bool = False,
 ) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
     """
     transform function to matmul if necessary (e.g. transform conv2d with im2col)
     """
     if layout is None:
         layout = ["a", "a", "a"]
@@ -543,10 +532,7 @@ def get_tensorized_func_and_tags(
         conditions = []
         conditions.append(len(block_stmt.reads) == 2)
         conditions.append(len(block_stmt.writes) == 1)
-        conditions.append(
-            len(
-                collect_block_iter_vars_used_in_access_region(block_stmt,
-                                                              block_stmt.writes[0].region)) > 0)
+        conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0)
         return all(conditions)
     # step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
@@ -592,10 +578,7 @@ def get_tensorized_func_and_tags(
         return axes
     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block_stmt.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars)
     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)
@@ -626,7 +609,7 @@ def get_tensorized_func_and_tags(
     # When the func is a dequantize like ops, we should consider the M
     require_block_reduce = False
     # And we only support float16 for now
-    if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
+    if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]:
         for arg in func.params:
             inp_shape = func.buffer_map[arg].shape
             M = inp_shape[0]
@@ -645,9 +628,7 @@ def get_tensorized_func_and_tags(
     if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
-            logger.debug(
-                f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore"
-            )
+            logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore")
             return func, None
     # reindex and transform functions
@@ -676,7 +657,7 @@ def get_tensorized_func_and_tags(
         else:
             raise ValueError(f"Unknown IterVar type {iter_type}")
-        if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
+        if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold:
             return func, None
     tags = analysis_tensorcore_tags(sch, main_block, target)
     return sch.mod["main"], tags
@@ -686,8 +667,10 @@ def get_tensorized_func_and_tags(
 def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"):
     from bitblas.tl.mma_layout import (  # pylint: disable=import-outside-toplevel
-        ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout,
-        ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b,
+        ldmatrix_32x8_to_shared_16x16_layout,
+        ldmatrix_trans_32x8_to_shared_16x16_layout,
+        ldmatrix_32x16_to_shared_16x32_layout_a,
+        ldmatrix_32x16_to_shared_16x32_layout_b,
     )
     assert dtype in [
@@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
         return ldmatrix_layout(thread_id, local_id)
     if dtype in ["bfloat16", "float16"]:
-        ldmatrix_index_map = (
-            ldmatrix_trans_permutation_16x16_32x8_16x16
-            if trans else ldmatrix_permutation_16x16_32x8_16x16)
+        ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16
     else:
         ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16
@@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
 # Ladder weight propagation, which can be used to avoid the ldmatrix
 # Instructions.
 def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
     def shared_32x8_to_mma_32x8_layout(i, j):
         thread_id = (i % 8) * 4 + (j // 2)
         local_id = (i // 8) * 2 + (j % 2)
@@ -837,8 +817,7 @@ def layout_propagate_chain(
             scaling_factor = 1
             for i, j in zip(write.buffer.shape, read.buffer.shape):
                 scaling_factor *= i // j
-            final_indices = list(
-                index_map.map_indices(tmp_index_map.map_indices(write_indices)))
+            final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices)))
             final_indices[-1] = final_indices[-1] // scaling_factor
             index_map = IndexMap(
                 write_indices,
......
@@ -2,7 +2,6 @@
 class Block:
     def __init__(self, start, end, is_free):
         self.start = start
         self.end = end
@@ -21,7 +20,6 @@ class Block:
 class BestFit:
     def __init__(self, align=32):
         self.limit = 0
         self.list = []
@@ -31,16 +29,14 @@ class BestFit:
         size = (size + self.align - 1) // self.align * self.align
         found = None
         for block in self.list:
-            if block.is_free and block.size() >= size and (not found or
-                                                           found.size() > block.size()):
+            if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
                 found = block
         if found:
             found.is_free = False
             remain = found.size() - size
             if remain != 0:
                 found.end -= remain
-                self.list.insert(
-                    self.list.index(found) + 1, Block(found.end, found.end + remain, True))
+                self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True))
             return found
         elif len(self.list) > 0 and self.list[-1].is_free:
             add = size - self.list[-1].size()
......
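
The BestFit hunk above keeps the usual round-up-to-alignment idiom on a single line. A standalone check of that arithmetic, assuming the 32-byte default from the constructor:

def align_up(size: int, align: int = 32) -> int:
    # Round size up to the next multiple of align using integer arithmetic,
    # mirroring the expression in BestFit: (size + align - 1) // align * align.
    return (size + align - 1) // align * align

assert align_up(70) == 96  # 70 bytes round up to 3 * 32
assert align_up(64) == 64  # already-aligned sizes are unchanged
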
"""Hint definition for schedule""" """Hint definition for schedule"""
from tvm import DataType from tvm import DataType
from . import PrimFuncNode from . import PrimFuncNode
import numpy as np import numpy as np
...@@ -60,7 +61,7 @@ class Stride: ...@@ -60,7 +61,7 @@ class Stride:
strided_elem = original_shape strided_elem = original_shape
else: else:
assert self.ax < len(shape) assert self.ax < len(shape)
strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride
assert strided_elem >= original_shape assert strided_elem >= original_shape
return int(strided_elem) return int(strided_elem)
...@@ -217,7 +218,7 @@ class Hint: ...@@ -217,7 +218,7 @@ class Hint:
return dic return dic
@classmethod @classmethod
def from_dict(cls, dic: dict) -> 'Hint': def from_dict(cls, dic: dict) -> "Hint":
hint = cls() hint = cls()
for k, v in dic.items(): for k, v in dic.items():
setattr(hint, k, v) setattr(hint, k, v)
......
"""PrimFunc Wrapper and Block information Analaysis""" """PrimFunc Wrapper and Block information Analaysis"""
from __future__ import annotations from __future__ import annotations
import tvm import tvm
...@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func): ...@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func):
class BlockAnalyzer: class BlockAnalyzer:
def __init__(self, sch) -> None: def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch self.sch: tir.Schedule = sch
self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch) self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
...@@ -92,7 +92,6 @@ class Edge: ...@@ -92,7 +92,6 @@ class Edge:
class Node: class Node:
def __init__(self, tags: dict | None = None, name: str = "Node") -> None: def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
self.name = name self.name = name
if tags is None: if tags is None:
...@@ -177,7 +176,6 @@ class Node: ...@@ -177,7 +176,6 @@ class Node:
class PlaceHolderNode(Node): class PlaceHolderNode(Node):
def __init__(self, name=""): def __init__(self, name=""):
super().__init__(name="PlaceHolder_" + name) super().__init__(name="PlaceHolder_" + name)
...@@ -189,11 +187,7 @@ class PlaceHolderNode(Node): ...@@ -189,11 +187,7 @@ class PlaceHolderNode(Node):
class PrimFuncNode(Node): class PrimFuncNode(Node):
def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None:
def __init__(self,
prim_func: PrimFunc,
tags: dict | None = None,
name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name) super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func) self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func) self.sch: tir.Schedule = tir.Schedule(self.prim_func)
...@@ -227,7 +221,7 @@ class PrimFuncNode(Node): ...@@ -227,7 +221,7 @@ class PrimFuncNode(Node):
for dst_id, n in enumerate(inputs): for dst_id, n in enumerate(inputs):
if isinstance(n, Node): if isinstance(n, Node):
n = (n, 0) n = (n, 0)
assert (len(n) == 2) assert len(n) == 2
src_node, src_id = n[0], n[1] src_node, src_id = n[0], n[1]
edge = Edge(src_node, self, src_id, dst_id) edge = Edge(src_node, self, src_id, dst_id)
self._in_edges.append(edge) self._in_edges.append(edge)
...@@ -338,9 +332,8 @@ class PrimFuncNode(Node): ...@@ -338,9 +332,8 @@ class PrimFuncNode(Node):
if rstep is None: if rstep is None:
rstep = {} rstep = {}
shape = { shape = {
self.block_analyzer.get_output_buffers(block)[0].name: [ self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile]
tvm.arith.ConstIntBound(0, val - 1) for val in tile for block in self.schedule_stages
] for block in self.schedule_stages
} }
return self.ana.infer(shape, rstep, targets) return self.ana.infer(shape, rstep, targets)
...@@ -356,10 +349,7 @@ class PrimFuncNode(Node): ...@@ -356,10 +349,7 @@ class PrimFuncNode(Node):
results.append(shapes[arg.name]) results.append(shapes[arg.name])
continue continue
# should not exceed original shape # should not exceed original shape
trimmed_shape = [ trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))]
self.extent_wrapper(i)
for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
]
results.append(trimmed_shape) results.append(trimmed_shape)
return results return results
...@@ -380,10 +370,8 @@ class PrimFuncNode(Node): ...@@ -380,10 +370,8 @@ class PrimFuncNode(Node):
propagate_shape = shapes[arg.name] propagate_shape = shapes[arg.name]
buffer_shape = args[i].shape buffer_shape = args[i].shape
if len(buffer_shape) > len(propagate_shape): if len(buffer_shape) > len(propagate_shape):
buffer_shape = buffer_shape[-len(propagate_shape):] buffer_shape = buffer_shape[-len(propagate_shape) :]
trimmed_shape = [ trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))]
self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))
]
results.append(trimmed_shape) results.append(trimmed_shape)
return results return results
...@@ -412,10 +400,7 @@ class PrimFuncNode(Node): ...@@ -412,10 +400,7 @@ class PrimFuncNode(Node):
def get_reduce_inputs_dtype(self): def get_reduce_inputs_dtype(self):
if self.reduction_block is None: if self.reduction_block is None:
return {} return {}
return { return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)}
b.name: tvm.DataType(b.dtype)
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
@functools.lru_cache @functools.lru_cache
def infer_tensorcore_axis(self) -> tuple[int]: def infer_tensorcore_axis(self) -> tuple[int]:
...@@ -425,8 +410,7 @@ class PrimFuncNode(Node): ...@@ -425,8 +410,7 @@ class PrimFuncNode(Node):
C_ax_m, C_ax_n = self.get_tag("tensorcore_config") C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok
output_buffer_shape = ( output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape
self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
valid_region = [] valid_region = []
for region in output_buffer_shape: for region in output_buffer_shape:
if region.value == 1: if region.value == 1:
...@@ -438,8 +422,7 @@ class PrimFuncNode(Node): ...@@ -438,8 +422,7 @@ class PrimFuncNode(Node):
def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
spatial_dim = self.get_space_dim() spatial_dim = self.get_space_dim()
assert len(valid_region) == len( assert len(valid_region) == len(spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
cl_shapes = [1] * len(spatial_dim) cl_shapes = [1] * len(spatial_dim)
cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n
...@@ -467,9 +450,11 @@ class PrimFuncNode(Node): ...@@ -467,9 +450,11 @@ class PrimFuncNode(Node):
shapes, _ = self.propagate(shape, rstep) shapes, _ = self.propagate(shape, rstep)
def is_broadcast_pattern(buffer, output_buffer): def is_broadcast_pattern(buffer, output_buffer):
return (buffer in self.args and return (
len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and buffer in self.args
np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) and len(shapes[output_buffer.name]) > len(shapes[buffer.name])
and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])
)
def is_after_reduce_stage(block): def is_after_reduce_stage(block):
if not self.reduction_block: if not self.reduction_block:
...@@ -491,8 +476,8 @@ class PrimFuncNode(Node): ...@@ -491,8 +476,8 @@ class PrimFuncNode(Node):
output_buffer = self.block_analyzer.get_output_buffers(block)[0] output_buffer = self.block_analyzer.get_output_buffers(block)[0]
for buffer in self.block_analyzer.get_input_buffers(block): for buffer in self.block_analyzer.get_input_buffers(block):
cache = buffer.name not in cached_tensor and ( cache = buffer.name not in cached_tensor and (
is_broadcast_pattern(buffer, output_buffer) or is_broadcast_pattern(buffer, output_buffer) or self.block_analyzer.get_block_info(block).is_reduction()
self.block_analyzer.get_block_info(block).is_reduction()) )
if not cache: if not cache:
continue continue
cached_tensor.append(buffer.name) cached_tensor.append(buffer.name)
...@@ -500,8 +485,7 @@ class PrimFuncNode(Node): ...@@ -500,8 +485,7 @@ class PrimFuncNode(Node):
continue # cache after reduce op can often reuse buffer in reduce stage continue # cache after reduce op can often reuse buffer in reduce stage
if buffer.name in stride_map: if buffer.name in stride_map:
num_elem = stride_map[buffer.name].compute_elements_from_shape( num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name])
shapes[buffer.name])
else: else:
num_elem = np.prod(shapes[buffer.name]) num_elem = np.prod(shapes[buffer.name])
buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)
...@@ -514,7 +498,6 @@ class PrimFuncNode(Node): ...@@ -514,7 +498,6 @@ class PrimFuncNode(Node):
class OutputNode(Node): class OutputNode(Node):
def __init__(self, node, id=0): def __init__(self, node, id=0):
super().__init__(name="OutputNode") super().__init__(name="OutputNode")
# connect node and output node # connect node and output node
...@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]: ...@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]:
input_ready_count[dst_node] = len(dst_node.inputs) input_ready_count[dst_node] = len(dst_node.inputs)
list_of_nodes.append(dst_node) list_of_nodes.append(dst_node)
input_ready_count[dst_node] -= 1 input_ready_count[dst_node] -= 1
assert (input_ready_count[dst_node] >= 0) assert input_ready_count[dst_node] >= 0
if input_ready_count[dst_node] == 0: if input_ready_count[dst_node] == 0:
ready.append(dst_node) ready.append(dst_node)
assert (len(list_of_nodes) == len(output_list)) assert len(list_of_nodes) == len(output_list)
return output_list return output_list
def find_topo_sort_priority(output_node_list) -> list[Node]: def find_topo_sort_priority(output_node_list) -> list[Node]:
import sys import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def topo_sort_get_layer(node, topo_layer): def topo_sort_get_layer(node, topo_layer):
...@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]: ...@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
if node in visited: if node in visited:
return return
visited.add(node) visited.add(node)
ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True)
key=lambda n: topo_layer[n],
reverse=True)
for n in ordered_input_nodes: for n in ordered_input_nodes:
topo_sort_dfs(n, visited, topo_order) topo_sort_dfs(n, visited, topo_order)
topo_order.append(node) topo_order.append(node)
...@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]: ...@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
def find_topo_sort(output_node_list) -> list[Node]: def find_topo_sort(output_node_list) -> list[Node]:
def topo_sort_dfs(node, visited, topo_order): def topo_sort_dfs(node, visited, topo_order):
if node in visited: if node in visited:
return return
......
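
The topo_order hunk above only strips redundant parentheses from the asserts, but the counters it asserts over are the standard ready-count (Kahn-style) topological ordering. A generic, self-contained sketch of that idea, independent of this file's Node/Edge classes:

from collections import deque


def kahn_topo_order(edges: dict[str, list[str]]) -> list[str]:
    # edges maps each node to the nodes that depend on it
    indegree = {n: 0 for n in edges}
    for dsts in edges.values():
        for d in dsts:
            indegree[d] = indegree.get(d, 0) + 1
    ready = deque(n for n, c in indegree.items() if c == 0)
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for d in edges.get(n, []):
            indegree[d] -= 1
            assert indegree[d] >= 0  # same invariant the asserts above check
            if indegree[d] == 0:
                ready.append(d)
    return order


assert kahn_topo_order({"a": ["b"], "b": ["c"], "c": []}) == ["a", "b", "c"]
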
"""Policy for cuda core schedule""" """Policy for cuda core schedule"""
from __future__ import annotations from __future__ import annotations
import functools import functools
import math import math
...@@ -36,20 +37,14 @@ class DefaultPolicy: ...@@ -36,20 +37,14 @@ class DefaultPolicy:
self.rasterization = NoRasterization() self.rasterization = NoRasterization()
@classmethod @classmethod
def from_prim_func(cls, def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"):
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: dict | None = None,
name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name) return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod @classmethod
def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None): def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
return cls(arch, tags)._init_with_output_nodes(nodes) return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self, def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy:
func: tvm.tir.PrimFunc,
name: str = "PrimFuncNode") -> DefaultPolicy:
if func is not None and isinstance(func, tvm.tir.PrimFunc): if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name) self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
...@@ -60,9 +55,7 @@ class DefaultPolicy: ...@@ -60,9 +55,7 @@ class DefaultPolicy:
return self return self
def _init_with_output_nodes(self, output_nodes: list[OutputNode]): def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
self.ordered_nodes = list( self.ordered_nodes = list(filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes)))
filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes)))
for node in self.ordered_nodes: for node in self.ordered_nodes:
node.update_tags(self.tags) node.update_tags(self.tags)
...@@ -102,13 +95,14 @@ class DefaultPolicy: ...@@ -102,13 +95,14 @@ class DefaultPolicy:
def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
_steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()] _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()]
steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)]
for i in range(len(steps)): for i in range(len(steps)):
added = list( added = list(
filter( filter(
lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i],
[2, 4, 8, 16, 32], [2, 4, 8, 16, 32],
)) )
)
steps[i].extend(added) steps[i].extend(added)
steps[i] = sorted(steps[i]) steps[i] = sorted(steps[i])
visited_tiles = {} visited_tiles = {}
...@@ -190,10 +184,7 @@ class DefaultPolicy: ...@@ -190,10 +184,7 @@ class DefaultPolicy:
""" """
tile_map = {} tile_map = {}
for node in self.output_nodes: for node in self.output_nodes:
tile_map[node] = [ tile_map[node] = [tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))]
tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i]
for i in range(len(tile))
]
return tile_map return tile_map
def compute_workload_per_item(self, output_tile) -> float: def compute_workload_per_item(self, output_tile) -> float:
...@@ -304,8 +295,7 @@ class DefaultPolicy: ...@@ -304,8 +295,7 @@ class DefaultPolicy:
score = 0 score = 0
shape = node.propagate_inputs(tile, rstep=rstep) shape = node.propagate_inputs(tile, rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers): for i, input_buffer in enumerate(node.input_buffers):
read_transaction_elements = self.arch.transaction_size[1] // ( read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8)
(node.get_buffer_dtype(input_buffer).bits + 7) // 8)
score += sim( score += sim(
int(coalesced_factor(shape[i], input_buffer.shape)), int(coalesced_factor(shape[i], input_buffer.shape)),
read_transaction_elements, read_transaction_elements,
...@@ -380,17 +370,13 @@ class DefaultPolicy: ...@@ -380,17 +370,13 @@ class DefaultPolicy:
return None return None
return max(candidates, key=lambda x: x[1])[0] return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = { cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy() new_rstep_map = rstep_map.copy()
while True: while True:
new_rstep_id = _enlarge(cur_rstep_id) new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None: if new_rstep_id is None:
break break
new_rstep_map[node] = { new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
}
old_rstep_map = td.rstep_map old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map td.rstep_map = new_rstep_map
smem_usage, _ = self._compute_shared_memory_usage(td) smem_usage, _ = self._compute_shared_memory_usage(td)
...@@ -434,15 +420,14 @@ class DefaultPolicy: ...@@ -434,15 +420,14 @@ class DefaultPolicy:
if edge.src_node.is_placeholder(): if edge.src_node.is_placeholder():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
read_transaction_elements = self.arch.transaction_size[1] // nbytes read_transaction_elements = self.arch.transaction_size[1] // nbytes
traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes
read_transaction_elements) * nbytes
for edge in node.outputs: for edge in node.outputs:
if edge.dst_node.is_output(): if edge.dst_node.is_output():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
write_transaction_elements = self.arch.transaction_size[0] // nbytes write_transaction_elements = self.arch.transaction_size[0] // nbytes
traffic += coalesced_tensor_shape(output_shapes[edge.src_id], traffic += (
node.get_shape(edge.src_id), coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements) * nbytes
write_transaction_elements) * nbytes )
return traffic, op_tile_map return traffic, op_tile_map
...@@ -487,10 +472,7 @@ class DefaultPolicy: ...@@ -487,10 +472,7 @@ class DefaultPolicy:
cached_tensors_map = {} cached_tensors_map = {}
def can_free(node, out_id): def can_free(node, out_id):
for edge in node.outputs: return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs)
if edge.src_id == out_id and edge.dst_node not in processed:
return False
return True
for node in self.ordered_nodes: for node in self.ordered_nodes:
node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node) node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node)
...@@ -528,9 +510,7 @@ class DefaultPolicy: ...@@ -528,9 +510,7 @@ class DefaultPolicy:
Tuple[Dict, Dict] Tuple[Dict, Dict]
A tuple of dictionaries containing the output strides and tensor strides. A tuple of dictionaries containing the output strides and tensor strides.
""" """
output_strides = { output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
tensor_strides = {} tensor_strides = {}
return output_strides, tensor_strides return output_strides, tensor_strides
...@@ -551,8 +531,7 @@ class DefaultPolicy: ...@@ -551,8 +531,7 @@ class DefaultPolicy:
output_strides_map = {} output_strides_map = {}
tensor_strides_map = {} tensor_strides_map = {}
for node in self.ordered_nodes: for node in self.ordered_nodes:
output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td)
node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict: def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
...@@ -582,9 +561,7 @@ class DefaultPolicy: ...@@ -582,9 +561,7 @@ class DefaultPolicy:
output_shape = self.output_nodes[0].get_space_dim() output_shape = self.output_nodes[0].get_space_dim()
td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)]))
# estimated reg usage # estimated reg usage
reg_usage = int(2 * max([ reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes]))
np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes
]))
if reg_usage > self.arch.reg_cap: if reg_usage > self.arch.reg_cap:
td.valid = False td.valid = False
return td return td
...@@ -609,13 +586,10 @@ class DefaultPolicy: ...@@ -609,13 +586,10 @@ class DefaultPolicy:
for node in self.ordered_nodes: for node in self.ordered_nodes:
if np.prod(td.get_tile(node)) == 0: if np.prod(td.get_tile(node)) == 0:
return False return False
node_grid_size = np.prod([ node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())])
(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())
])
if node_grid_size != td.grid_size: if node_grid_size != td.grid_size:
return False return False
if (hasattr(node, "reduce_op") and node.reduce_op is not None and if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile):
len(node.reduce_op.axis) == len(td.output_tile)):
for i, tile_extent in enumerate(td.output_tile): for i, tile_extent in enumerate(td.output_tile):
if node.reduce_op.axis[i].dom.extent % tile_extent: if node.reduce_op.axis[i].dom.extent % tile_extent:
return False return False
...@@ -639,23 +613,22 @@ class DefaultPolicy: ...@@ -639,23 +613,22 @@ class DefaultPolicy:
node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes]
max_block_size = functools.reduce(math.gcd, node_space_sizes) max_block_size = functools.reduce(math.gcd, node_space_sizes)
if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes):
node_space_sizes): node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes]
node_reduce_sizes = [
int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes
]
total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)]
max_possible_size = functools.reduce(math.gcd, total_sizes) max_possible_size = functools.reduce(math.gcd, total_sizes)
possible_block_sizes = list( possible_block_sizes = list(
filter( filter(
lambda x: x % max_block_size == 0 and x <= 1024, lambda x: x % max_block_size == 0 and x <= 1024,
get_all_factors(max_possible_size), get_all_factors(max_possible_size),
)) )
)
possible_block_sizes = list( possible_block_sizes = list(
filter( # either be a factor of space or cover fully cover the space filter( # either be a factor of space or cover fully cover the space
lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]),
possible_block_sizes, possible_block_sizes,
)) )
)
factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
return factor_ordered return factor_ordered
else: else:
...@@ -821,8 +794,7 @@ class DefaultPolicy: ...@@ -821,8 +794,7 @@ class DefaultPolicy:
vectorize_result = {} vectorize_result = {}
for tensor, shape in shapes.items(): for tensor, shape in shapes.items():
for v in vectorize_sizes: for v in vectorize_sizes:
if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v):
is_type_allowed(dtypes[tensor], v)):
vectorize_result[tensor] = v vectorize_result[tensor] = v
break break
return vectorize_result return vectorize_result
......
"""Policy for tensorcore schedule""" """Policy for tensorcore schedule"""
from __future__ import annotations from __future__ import annotations
import tvm import tvm
import numpy as np import numpy as np
...@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__) ...@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
class TensorCorePolicy(DefaultPolicy): class TensorCorePolicy(DefaultPolicy):
# this is the trick for wmma. # this is the trick for wmma.
# However, for int8 mma, the wmma_k should be 32. # However, for int8 mma, the wmma_k should be 32.
wmma_k: int = 16 wmma_k: int = 16
...@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
A_high_ax = min(A_ax_m, A_ax_k) A_high_ax = min(A_ax_m, A_ax_k)
B_high_ax = min(B_ax_n, B_ax_k) B_high_ax = min(B_ax_n, B_ax_k)
C_high_ax = min(C_ax_m, C_ax_n) C_high_ax = min(C_ax_m, C_ax_n)
A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax)
B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax)
C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax)
return A_stride, B_stride, C_stride return A_stride, B_stride, C_stride
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
...@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
# get reduce input size # get reduce input size
target_transaction = self.arch.transaction_size[0] * 2 target_transaction = self.arch.transaction_size[0] * 2
# 512 bytes // type bits # 512 bytes // type bits
reduce_input_dtype = node.get_buffer_dtype( reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0])
node.block_analyzer.get_input_buffers(node.reduction_block)[0])
basic = (target_transaction * 8) // reduce_input_dtype.bits basic = (target_transaction * 8) // reduce_input_dtype.bits
result = {} result = {}
...@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
iter_name = iter_info.var.name iter_name = iter_info.var.name
iter_dom = iter_info.dom.extent iter_dom = iter_info.dom.extent
if iter_dom % 16 > 0: if iter_dom % 16 > 0:
result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding result[iter_name] = 16 if iter_dom < basic else basic # for the case of padding
elif iter_dom % basic == 0: elif iter_dom % basic == 0:
result[iter_name] = basic result[iter_name] = basic
else: else:
...@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
return False return False
if _check_small_tile(td): if _check_small_tile(td):
smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy() rstep_map = td.rstep_map.copy()
...@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
return rstep return rstep
def _shared_memory_usage(td: TileDict): def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map, return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])
td.tensor_strides_map[node])
def _score(rstep_id): def _score(rstep_id):
rstep = { rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
score = 0 score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
...@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy):
return None return None
return max(candidates, key=lambda x: x[1])[0] return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = { cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy() new_rstep_map = rstep_map.copy()
while True: while True:
new_rstep_id = _enlarge(cur_rstep_id) new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None: if new_rstep_id is None:
break break
new_rstep_map = { new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
for k in node.raxis
}
old_rstep_map = td.rstep_map old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td) smem_usage, _ = _shared_memory_usage(td)
...@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy):
break break
else: else:
cur_rstep_id = new_rstep_id cur_rstep_id = new_rstep_id
rstep = { rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
return rstep return rstep
for node in self.ordered_nodes: for node in self.ordered_nodes:
...@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy):
return super().get_node_reduce_step_candidates(node) return super().get_node_reduce_step_candidates(node)
else: else:
# must be a a multiple of wmma_k # must be a a multiple of wmma_k
return { return {k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis}
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}
def check_tile_shape_isvalid(self, td: TileDict): def check_tile_shape_isvalid(self, td: TileDict):
for node in self.ordered_nodes: for node in self.ordered_nodes:
...@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy):
td.tile_map[node][ax_n], td.tile_map[node][ax_n],
) )
# check the tile size is valid # check the tile size is valid
wmma_invalid = [ wmma_invalid = [block_m < wmma_m or block_n < wmma_n for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()]
block_m < wmma_m or block_n < wmma_n
for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
]
if all(wmma_invalid): if all(wmma_invalid):
return False return False
if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):
...@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy):
return super().compute_node_stride_map(node, td) return super().compute_node_stride_map(node, td)
use_layout = self._can_implement_layout(node, td) use_layout = self._can_implement_layout(node, td)
AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node))
td.get_rstep(node))
A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
tensor_strides = {} tensor_strides = {}
output_strides = { output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
tensor_strides = {} tensor_strides = {}
# when connected to shared input, should use full stride without rstep # when connected to shared input, should use full stride without rstep
for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):
...@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy):
overall_gmem_size_in_bytes: int = 0 overall_gmem_size_in_bytes: int = 0
for node in self.ordered_nodes: for node in self.ordered_nodes:
for buffer in node.input_buffers: for buffer in node.input_buffers:
overall_gmem_size_in_bytes += ( overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8
int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8)
return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes
conditions.append(_check_memory_size()) conditions.append(_check_memory_size())
......
@@ -2,7 +2,6 @@
 class Rasterization:
     panel_width_ = None
     def __init__(self) -> None:
@@ -18,7 +17,6 @@ class Rasterization:
 class NoRasterization(Rasterization):
     def __init__(self) -> None:
         super().__init__()
......
@@ -4,9 +4,7 @@ from tvm import arith
 class Statement:
-    def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
-                 range_map: OrderedDict):
+    def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict):
         self.output = output
         self.dependent_region = dependent_region
         self.var_map = var_map
@@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
 class InputShapeInference:
     def __init__(self, deps: list[Statement]):
         self.deps = deps
......
@@ -5,7 +5,6 @@ from tvm import arith, tir
 class Statement:
     def __init__(self, block_analyzer, block: BlockRV):
         self.block_analyzer = block_analyzer
         self.block = block
@@ -21,9 +20,7 @@ class Statement:
         if len(self.dependent_region[input_name]) != 1:
             return None
         indices = self.dependent_region[input_name][0]
-        iter_map_range = {
-            _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)
-        }
+        iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)}
         iter_map_result = arith.detect_iter_map(
             indices,
             iter_map_range,
@@ -77,7 +74,6 @@ class TensorDepNode:
 class DependencyAnalysis:
     def __init__(self, deps):
         self.deps = deps
         # issue: duplicate name when we have two same ops.
@@ -112,8 +108,7 @@ class DependencyAnalysis:
     def traverse_dependencies(self, compute):
         if isinstance(compute, Statement):
-            node = self.get_or_create_node(
-                compute.block_analyzer.get_output_buffers(compute.block)[0].name)
+            node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name)
             # Loop through input tensors
             for input_buffer in compute.block_analyzer.get_input_buffers(compute.block):
                 # Get the input node
@@ -167,7 +162,6 @@ class DependencyAnalysis:
 class InputShapeInference:
     def __init__(self, deps: list[Statement]):
         self.deps = deps
         self.target_mapping = {}
@@ -183,16 +177,11 @@ class InputShapeInference:
         if targets in self.target_mapping:
             return self.target_mapping[targets]
         # should be buffer name instead of block name
-        name2dep = {
-            dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps
-        }
+        name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps}
         mapping = {}
         input_vars = []
         for target in targets:
-            vars = [
-                iter.var
-                for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)
-            ]
+            vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)]
             input_vars.append(vars)
             mapping[target] = [vars]
         ana = arith.Analyzer()
@@ -221,13 +210,8 @@ class InputShapeInference:
                 mapping[input_name] = []
             for indices in indices_list:
                 for region in regions:
-                    vmap = {
-                        k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v)
-                        for k, v in zip(ax_vars, indices)
-                    }
-                    region = [
-                        ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region
-                    ]
+                    vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)}
+                    region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region]
                     if not region_exist_in_list(region, mapping[input_name]):
                         mapping[input_name].append(region)
         buffers = []
@@ -241,10 +225,7 @@ class InputShapeInference:
         self.target_mapping[targets] = input_vars, mapping
         return input_vars, mapping
-    def infer(self,
-              shape: dict[str, list[arith.ConstIntBound]],
-              rstep: dict[str, int] = None,
-              targets=None):
+    def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None):
         if rstep is None:
             rstep = {}
         compute_targets = tuple(shape.keys())
@@ -258,8 +239,7 @@ class InputShapeInference:
         for ax in self.reduce_axes:
             # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value.
             if ax.var.name in rstep:
-                bound = arith.ConstIntBound(
-                    int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
+                bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
             else:
                 bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1))
             ana.update(ax.var, bound, True)
@@ -312,14 +292,11 @@ class InputShapeInference:
         for name, regions in mapping.items():
             region = regions[0]
-            result[name] = [
-                ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region
-            ]
+            result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region]
         return result
 def region_exist_in_list(a, list) -> bool:
     def expr_is_same(a, b) -> bool:
         if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm):
             return a.value == b.value
......
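The infer() path above seeds an arith.Analyzer with a per-axis ConstIntBound before simplifying the input regions. Below is a minimal, self-contained sketch of that bound-propagation pattern using only the public tvm.arith API; the variable names and extents are illustrative placeholders, not values from this diff.

from tvm import arith, tir

# Stand-in iteration variables; infer() derives these from the block's axes.
i = tir.Var("i", "int32")
k = tir.Var("k", "int32")

ana = arith.Analyzer()
# Constrain each axis, mirroring ana.update(ax.var, bound, True) above.
ana.update(i, arith.ConstIntBound(0, 127), True)
ana.update(k, arith.ConstIntBound(0, 31), True)

# Ask for the conservative bound of an index expression such as i * 32 + k.
bound = ana.const_int_bound(i * 32 + k)
print(bound.min_value, bound.max_value)  # 0 4095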
...@@ -2,7 +2,12 @@ ...@@ -2,7 +2,12 @@
from abc import ABC, abstractmethod # For defining abstract base classes from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch) TileDevice,
is_volta_arch,
is_ampere_arch,
is_cdna_arch,
auto_infer_current_arch,
)
from ..roller.hint import Hint # Import the Hint class from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class from ..roller.node import OutputNode # Import the OutputNode class
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
...@@ -41,7 +46,7 @@ class BaseTemplate(ABC): ...@@ -41,7 +46,7 @@ class BaseTemplate(ABC):
""" """
pass pass
def with_arch(self, arch: TileDevice) -> 'BaseTemplate': def with_arch(self, arch: TileDevice) -> "BaseTemplate":
""" """
Sets the architecture for this template and returns itself. Sets the architecture for this template and returns itself.
...@@ -109,7 +114,7 @@ class BaseTemplate(ABC): ...@@ -109,7 +114,7 @@ class BaseTemplate(ABC):
""" """
raise NotImplementedError("initialize_function is not implemented") raise NotImplementedError("initialize_function is not implemented")
def set_function(self, func: PrimFunc) -> 'BaseTemplate': def set_function(self, func: PrimFunc) -> "BaseTemplate":
""" """
Sets the function for this template and returns itself. Sets the function for this template and returns itself.
...@@ -122,7 +127,7 @@ class BaseTemplate(ABC): ...@@ -122,7 +127,7 @@ class BaseTemplate(ABC):
self._func = func self._func = func
return self return self
def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate': def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate":
""" """
Sets the output nodes for this template and returns itself. Sets the output nodes for this template and returns itself.
......
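The with_arch / set_function / set_output_nodes setters above all return self, so a template can be configured in a single chained expression. A toy, dependency-free sketch of that fluent pattern follows; the class and field names are stand-ins, not the repository's own API.

from dataclasses import dataclass

@dataclass
class _Template:
    arch: str = None
    func: object = None

    def with_arch(self, arch: str) -> "_Template":
        self.arch = arch
        return self  # returning self is what makes chaining possible

    def set_function(self, func) -> "_Template":
        self.func = func
        return self

t = _Template().with_arch("cuda").set_function(lambda x: x + 1)
print(t.arch, t.func(1))  # cuda 2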
...@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate): ...@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype (str): Data type used for accumulation. accum_dtype (str): Data type used for accumulation.
with_bias (bool): Whether to add a bias term. with_bias (bool): Whether to add a bias term.
""" """
# Operation-related configuration parameters # Operation-related configuration parameters
N: int # The number of input samples processed simultaneously in a batch. N: int # The number of input samples processed simultaneously in a batch.
C: int # The number of input feature maps. C: int # The number of input feature maps.
...@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate): ...@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers. AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
""" """
N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P
assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and assert (
isinstance(W, int) and isinstance(F, int) and isinstance(K, int) and isinstance(N, int)
isinstance(S, int) and isinstance(D, int) and and isinstance(C, int)
isinstance(P, int)), "Only Support Integer Params" and isinstance(H, int)
assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and and isinstance(W, int)
P > 0), "Params should be positive" and isinstance(F, int)
and isinstance(K, int)
and isinstance(S, int)
and isinstance(D, int)
and isinstance(P, int)
), "Only Support Integer Params"
assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive"
# Load configuration parameters # Load configuration parameters
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
...@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate): ...@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
te.if_then_else( te.if_then_else(
te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W), te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype), A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype),
tir.const(0, accum_dtype)), tir.const(0, accum_dtype),
axis=[kh, kw, c]) ),
axis=[kh, kw, c],
)
# Compute convolution result # Compute convolution result
C = te.compute( C = te.compute(
......
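The guarded read in the convolution body above zero-fills out-of-bounds input positions via te.if_then_else. The stand-alone TE sketch below reproduces that pattern; the NHWC/HWIO layouts and the index arithmetic mirror the template, while the concrete sizes, strides, and dtypes are placeholders chosen for the example.

from tvm import te, tir

N, H, W, C, F, K, S, D, P = 1, 8, 8, 4, 4, 3, 1, 1, 1
KH = KW = K
OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1
OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1
accum_dtype = "float32"

A = te.placeholder((N, H, W, C), name="A", dtype="float16")
B = te.placeholder((KH, KW, C, F), name="B", dtype="float16")
kh = te.reduce_axis((0, KH), name="kh")
kw = te.reduce_axis((0, KW), name="kw")
c = te.reduce_axis((0, C), name="c")

def conv_body(n, h, w, f_):
    h_in = h * S + kh * D - P
    w_in = w * S + kw * D - P
    # Out-of-bounds taps contribute zero instead of reading padded memory.
    return te.sum(
        te.if_then_else(
            te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
            A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f_].astype(accum_dtype),
            tir.const(0, accum_dtype),
        ),
        axis=[kh, kw, c],
    )

Out = te.compute((N, OH, OW, F), conv_body, name="Out")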
...@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_ ...@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_
@dataclass @dataclass
class FlashAttentionTemplate(BaseTemplate): class FlashAttentionTemplate(BaseTemplate):
_output_nodes: list[OutputNode] = None _output_nodes: list[OutputNode] = None
# Operation-related configuration parameters # Operation-related configuration parameters
...@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate): ...@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate):
""" """
A_indices = [b, i, k] A_indices = [b, i, k]
B_indices = [b, j, k] B_indices = [b, j, k]
return te.sum( return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
A[tuple(A_indices)].astype(accum_dtype) *
B[tuple(B_indices)].astype(accum_dtype),
axis=k)
# Compute matrix multiplication result # Compute matrix multiplication result
C = te.compute( C = te.compute(
......
...@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate): ...@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate):
N, K = self.N, self.K N, K = self.N, self.K
# Ensure M, N, K are valid positive integers # Ensure M, N, K are valid positive integers
assert (isinstance(M, int) and isinstance(N, int) and assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
isinstance(K, int)), "Only Support Integer M, N, K" assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
# Load configuration parameters # Load configuration parameters
trans_B = self.trans_B trans_B = self.trans_B
...@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate): ...@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate):
""" """
A_indices = [i, k] A_indices = [i, k]
B_indices = [k, j] if not trans_B else [j, k] B_indices = [k, j] if not trans_B else [j, k]
return te.sum( return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype),
axis=k)
# Compute matrix multiplication result # Compute matrix multiplication result
C = te.compute( C = te.compute(
......
...@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func ...@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func
@dataclass @dataclass
class GeneralReductionTemplate(BaseTemplate): class GeneralReductionTemplate(BaseTemplate):
# OP Related Config # OP Related Config
structure: str | list[str] = None structure: str | list[str] = None
shape: list[int] = None shape: list[int] = None
dtype: str = "float16" dtype: str = "float16"
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
roller_hints = get_roller_hints_from_func( roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False)
self._func, arch=arch, topk=topk, allow_gemv=False)
return roller_hints return roller_hints
def initialize_function(self) -> None: def initialize_function(self) -> None:
...@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate): ...@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate):
spatial_axes = [] spatial_axes = []
reduce_axes = [] reduce_axes = []
for i, axis_type in enumerate(self.structure): for i, axis_type in enumerate(self.structure):
if axis_type.upper() == 'S': if axis_type.upper() == "S":
spatial_axes.append((i, self.shape[i])) spatial_axes.append((i, self.shape[i]))
elif axis_type.upper() == 'R': elif axis_type.upper() == "R":
reduce_axes.append((i, self.shape[i])) reduce_axes.append((i, self.shape[i]))
else: else:
raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.") raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")
...@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate): ...@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate):
# Walk through the structure in order # Walk through the structure in order
for axis_type in self.structure: for axis_type in self.structure:
if axis_type.upper() == 'S': if axis_type.upper() == "S":
# use the next spatial_indices item # use the next spatial_indices item
full_index.append(spatial_indices[spatial_iter]) full_index.append(spatial_indices[spatial_iter])
spatial_iter += 1 spatial_iter += 1
......
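The structure string parsed above labels each axis as spatial ("S") or reduction ("R") and pairs it with the corresponding entry of shape. A quick self-contained illustration of that grouping, with a made-up structure and shape:

structure, shape = "SSR", [4, 8, 16]
spatial_axes, reduce_axes = [], []
for i, axis_type in enumerate(structure):
    if axis_type.upper() == "S":
        spatial_axes.append((i, shape[i]))
    elif axis_type.upper() == "R":
        reduce_axes.append((i, shape[i]))
    else:
        raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")
print(spatial_axes, reduce_axes)  # [(0, 4), (1, 8)] [(2, 16)]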
...@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate): ...@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate):
M, N, K = self.M, self.N, self.K M, N, K = self.M, self.N, self.K
# Ensure M, N, K are valid positive integers # Ensure M, N, K are valid positive integers
assert (isinstance(M, int) and isinstance(N, int) and assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
isinstance(K, int)), "Only Support Integer M, N, K" assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
# Load configuration parameters # Load configuration parameters
trans_A, trans_B = self.trans_A, self.trans_B trans_A, trans_B = self.trans_A, self.trans_B
...@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate): ...@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate):
""" """
A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed
B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed
return te.sum( return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype),
axis=k)
# Compute matrix multiplication result # Compute matrix multiplication result
C = te.compute( C = te.compute(
......
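The trans_A / trans_B handling above simply flips the index order fed to the reduction; the GEMVTemplate section earlier uses the same trick with only trans_B. A self-contained TE sketch of that indexing follows, with placeholder sizes and dtypes:

from tvm import te

M, N, K = 128, 128, 64
trans_A, trans_B = False, True
accum_dtype = "float32"

A = te.placeholder((M, K) if not trans_A else (K, M), name="A", dtype="float16")
B = te.placeholder((K, N) if not trans_B else (N, K), name="B", dtype="float16")
k = te.reduce_axis((0, K), name="k")

def gemm(i, j):
    A_indices = [i, k] if not trans_A else [k, i]  # flip when A is transposed
    B_indices = [k, j] if not trans_B else [j, k]  # flip when B is transposed
    return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

C = te.compute((M, N), gemm, name="C")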
...@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str: ...@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
""" """
def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, def get_roller_hints_from_func(
arch: TileDevice, func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False
topk: int = 10, ) -> list[Hint] | None:
tensorcore_only: bool = False,
allow_gemv: bool = False) -> list[Hint] | None:
func = None func = None
if isinstance(func_or_module, tir.PrimFunc): if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module func = func_or_module
...@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, ...@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
roller_hints = None roller_hints = None
if tensorcore_only: if tensorcore_only:
try: try:
tensorized_func, tags = get_tensorized_func_and_tags( tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
func, arch.target, allow_gemv=allow_gemv)
except Exception as e_msg: except Exception as e_msg:
logger.debug("Get tensorized func and tags failed: ", e_msg) logger.debug("Get tensorized func and tags failed: ", e_msg)
tags = None tags = None
...@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, ...@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
policy = DefaultPolicy.from_prim_func(func=func, arch=arch) policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None tensorized_func = None
try: try:
tensorized_func, tags = get_tensorized_func_and_tags( tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
func, arch.target, allow_gemv=allow_gemv)
except Exception as e_msg: except Exception as e_msg:
logger.debug("Get tensorized func and tags failed: ", e_msg) logger.debug("Get tensorized func and tags failed: ", e_msg)
tags = None tags = None
...@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, ...@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
return roller_hints return roller_hints
def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], def get_roller_hints_from_output_nodes(
arch: TileDevice, output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None
topk: int = 10, ) -> list[Hint] | None:
extra_tags: list[str] | None = None) -> list[Hint] | None:
assert isinstance(output_nodes, list), "The input should be a list of functions." assert isinstance(output_nodes, list), "The input should be a list of functions."
lints = [] lints = []
...@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], ...@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None) policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
lints = policy.emit_config(topk) lints = policy.emit_config(topk)
except Exception as e_msg: except Exception as e_msg:
logger.debug(f"Generate hints from output nodes failed: {e_msg}", logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy")
"fallback to default policy")
if len(lints) == 0: if len(lints) == 0:
policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None) policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
...@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], ...@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
if not isinstance(ir_module, IRModule): if not isinstance(ir_module, IRModule):
raise ValueError("Not supported type: ", type(ir_module)) raise ValueError("Not supported type: ", type(ir_module))
assert len(ir_module.get_global_vars()) == 1, ( assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule."
"The optimized module should only have one global variable for default schedule.")
func = list(ir_module.functions.values())[0] func = list(ir_module.functions.values())[0]
return func return func
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Util to invoke C/C++ compilers in the system.""" """Util to invoke C/C++ compilers in the system."""
import functools import functools
import os import os
import shutil import shutil
...@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils ...@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils
def _is_linux_like(): def _is_linux_like():
return (sys.platform == "darwin" or sys.platform.startswith("linux") or return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd")
sys.platform.startswith("freebsd"))
def _is_windows_like(): def _is_windows_like():
...@@ -90,7 +90,7 @@ def get_cplus_compiler(): ...@@ -90,7 +90,7 @@ def get_cplus_compiler():
def is_darwin(): def is_darwin():
return platform.system() == 'Darwin' return platform.system() == "Darwin"
def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None): def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
...@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll" ...@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc())) create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))
def cross_compiler(compile_func, def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None):
options=None,
output_format=None,
get_target_triple=None,
add_files=None):
"""Create a cross compiler function by specializing compile_func with options. """Create a cross compiler function by specializing compile_func with options.
This function can be used to construct compile functions that This function can be used to construct compile functions that
...@@ -363,13 +359,7 @@ def cross_compiler(compile_func, ...@@ -363,13 +359,7 @@ def cross_compiler(compile_func,
return _fcompile return _fcompile
def _linux_compile(output, def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False):
objects,
options,
compile_cmd,
cwd=None,
ccache_env=None,
compile_shared=False):
cmd = [compile_cmd] cmd = [compile_cmd]
if compile_cmd != "nvcc": if compile_cmd != "nvcc":
if compile_shared or output.endswith(".so") or output.endswith(".dylib"): if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
...@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None): ...@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
raise ValueError("ccache not found") raise ValueError("ccache not found")
try: try:
proc = subprocess.Popen( proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
(out, _) = proc.communicate() (out, _) = proc.communicate()
except FileNotFoundError: except FileNotFoundError:
raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)." raise RuntimeError(
"Make sure it's installed" "Can not find the LLVM clang for Windows clang.exe)."
" and the installation directory is in the %PATH% environment " "Make sure it's installed"
"variable. Prebuilt binaries can be found at: https://llvm.org/") \ " and the installation directory is in the %PATH% environment "
from None "variable. Prebuilt binaries can be found at: https://llvm.org/"
) from None
if proc.returncode != 0: if proc.returncode != 0:
msg = "Compilation error:\n" msg = "Compilation error:\n"
msg += " ".join(cmd) + "\n" msg += " ".join(cmd) + "\n"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM""" """Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm import runtime from tvm import runtime
...@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): ...@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
def adapt_tensor(arg): def adapt_tensor(arg):
if isinstance(arg, tensor_type): if isinstance(arg, tensor_type):
if arg.dtype in { if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype])
torch.float8_e5m2fnuz
}:
return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype])
return runtime.from_dlpack(to_dlpack_func(arg)) return runtime.from_dlpack(to_dlpack_func(arg))
return arg return arg
......
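The adapt_tensor helper above round-trips torch tensors into TVM through DLPack, taking a detour through an int8 view for the float8 dtypes. A minimal sketch of the plain-dtype path is shown below; it uses runtime.from_dlpack as in the diff and torch's standard DLPack export, and omits the float8 re-viewing step.

import torch
from torch.utils.dlpack import to_dlpack
from tvm import runtime

x = torch.arange(8, dtype=torch.float32)
# Export via DLPack and wrap as a TVM NDArray without copying the data.
tvm_nd = runtime.from_dlpack(to_dlpack(x))
print(tvm_nd.shape, tvm_nd.dtype)  # (8,) float32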