Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
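The hunks below all apply the same mechanical change: code previously wrapped by yapf is reflowed by ruff format, which collapses signatures and calls onto one line when they fit the line length, normalizes string literals to double quotes, drops redundant parentheses around assert and return expressions, and adds spaces inside extended slices. A minimal sketch of the before/after style follows; the function and its parameters are hypothetical and not taken from this diff.

# Hypothetical illustration only -- not part of this commit's diff.

# yapf style (before): wrapped, aligned parameters and single-quoted strings.
def normalize(name: str,
              layout: list[str] | None = None,
              default: str = 'n') -> str:
    return name if layout is None else default


# ruff format style (after): the signature collapses onto one line when it
# fits the configured line length, and string literals use double quotes.
def normalize(name: str, layout: list[str] | None = None, default: str = "n") -> str:
    return name if layout is None else default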
......@@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool:
class METAL(TileDevice):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = Target(target)
......@@ -16,6 +15,6 @@ class METAL(TileDevice):
__all__ = [
'is_metal_arch',
'METAL',
"is_metal_arch",
"METAL",
]
......@@ -19,6 +19,7 @@
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
from typing import Callable
from tvm import tir
......
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
......@@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block
return block
def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV,
buffer: tir.Buffer) -> int:
def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int:
"""traverse to find the arg index from the buffer"""
producers = sch.get_producers(main_block)
......@@ -226,9 +226,7 @@ def make_iter_fusion_index_map(
else:
fused_iters[trait.kind] = v_i
final_indices: list[tir.PrimExpr] = [
fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
]
final_indices: list[tir.PrimExpr] = [fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order]
return tir.IndexMap(input_iters, final_indices, None)
......@@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
return A_traits, B_traits, C_traits, block_traits
def get_index_map(block: tir.Block,
layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
"""Get index maps for the block
Parameters
......@@ -343,10 +340,7 @@ def get_index_map(block: tir.Block,
return axes
def is_common_reduce(var: Var) -> bool:
for iter_var in block.iter_vars:
if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
return True
return False
return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars)
def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
......@@ -384,17 +378,17 @@ def get_index_map(block: tir.Block,
if kind == "C":
return [IterKind.kIter_S, primary_iter, secondary_iter]
else:
return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region)
else [IterKind.kIter_S, reduction_iter, spatial_iter])
return (
[IterKind.kIter_S, spatial_iter, reduction_iter]
if check_last_trait(region)
else [IterKind.kIter_S, reduction_iter, spatial_iter]
)
else:
raise ValueError(f"Unknown layout {layout}")
A_index_map = make_iter_fusion_index_map(
A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
B_index_map = make_iter_fusion_index_map(
B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
C_index_map = make_iter_fusion_index_map(
C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
matmul_index_map = make_iter_fusion_index_map(
block_traits,
......@@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None:
has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
if not has_uint_input:
return False
return not (len(block_stmt.writes) != 1 or
"float" not in str(block_stmt.writes[0].buffer.dtype))
return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype))
dequantize_blocks = [block for block in blocks if is_dequantize(block)]
return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
......@@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
return None
axes.extend(undefined_vars(r.min))
# remove trivial axis
trivial_vars = set(
iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
axes = [axis for axis in axes if axis not in trivial_vars]
# remove duplicate axis
axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]]
......@@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:]
rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:]
is_identity = list(lhs_access_vars) == list(rhs_access_vars)
is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(
rhs_access_vars)
is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars)
return is_identity, is_transpose
......@@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]
return result_blocks
def normalize_to_matmul(sch: tir.Schedule,
main_block: BlockRV,
layout: list[str] | None = None) -> tir.Schedule | None:
def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None:
if layout is None:
layout = ["n", "t", "n"]
block_stmt = sch.get(main_block)
......@@ -543,10 +532,7 @@ def get_tensorized_func_and_tags(
conditions = []
conditions.append(len(block_stmt.reads) == 2)
conditions.append(len(block_stmt.writes) == 1)
conditions.append(
len(
collect_block_iter_vars_used_in_access_region(block_stmt,
block_stmt.writes[0].region)) > 0)
conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0)
return all(conditions)
# step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
......@@ -592,10 +578,7 @@ def get_tensorized_func_and_tags(
return axes
def is_common_reduce(var: Var) -> bool:
for iter_var in block_stmt.iter_vars:
if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
return True
return False
return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars)
def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
......@@ -626,7 +609,7 @@ def get_tensorized_func_and_tags(
# When the func is a dequantize like ops, we should consider the M
require_block_reduce = False
# And we only support float16 for now
if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]:
for arg in func.params:
inp_shape = func.buffer_map[arg].shape
M = inp_shape[0]
......@@ -645,9 +628,7 @@ def get_tensorized_func_and_tags(
if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
logger.debug(
f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore"
)
logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore")
return func, None
# reindex and transform functions
......@@ -676,7 +657,7 @@ def get_tensorized_func_and_tags(
else:
raise ValueError(f"Unknown IterVar type {iter_type}")
if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold:
return func, None
tags = analysis_tensorcore_tags(sch, main_block, target)
return sch.mod["main"], tags
......@@ -686,8 +667,10 @@ def get_tensorized_func_and_tags(
def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"):
from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel
ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b,
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
)
assert dtype in [
......@@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
return ldmatrix_layout(thread_id, local_id)
if dtype in ["bfloat16", "float16"]:
ldmatrix_index_map = (
ldmatrix_trans_permutation_16x16_32x8_16x16
if trans else ldmatrix_permutation_16x16_32x8_16x16)
ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16
else:
ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16
......@@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
# Ladder weight propagation, which can be used to avoid the ldmatrix
# Instructions.
def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
def shared_32x8_to_mma_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j // 2)
local_id = (i // 8) * 2 + (j % 2)
......@@ -837,8 +817,7 @@ def layout_propagate_chain(
scaling_factor = 1
for i, j in zip(write.buffer.shape, read.buffer.shape):
scaling_factor *= i // j
final_indices = list(
index_map.map_indices(tmp_index_map.map_indices(write_indices)))
final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices)))
final_indices[-1] = final_indices[-1] // scaling_factor
index_map = IndexMap(
write_indices,
......
......@@ -2,7 +2,6 @@
class Block:
def __init__(self, start, end, is_free):
self.start = start
self.end = end
......@@ -21,7 +20,6 @@ class Block:
class BestFit:
def __init__(self, align=32):
self.limit = 0
self.list = []
......@@ -31,16 +29,14 @@ class BestFit:
size = (size + self.align - 1) // self.align * self.align
found = None
for block in self.list:
if block.is_free and block.size() >= size and (not found or
found.size() > block.size()):
if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
found = block
if found:
found.is_free = False
remain = found.size() - size
if remain != 0:
found.end -= remain
self.list.insert(
self.list.index(found) + 1, Block(found.end, found.end + remain, True))
self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True))
return found
elif len(self.list) > 0 and self.list[-1].is_free:
add = size - self.list[-1].size()
......
"""Hint definition for schedule"""
from tvm import DataType
from . import PrimFuncNode
import numpy as np
......@@ -60,7 +61,7 @@ class Stride:
strided_elem = original_shape
else:
assert self.ax < len(shape)
strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride
strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride
assert strided_elem >= original_shape
return int(strided_elem)
......@@ -217,7 +218,7 @@ class Hint:
return dic
@classmethod
def from_dict(cls, dic: dict) -> 'Hint':
def from_dict(cls, dic: dict) -> "Hint":
hint = cls()
for k, v in dic.items():
setattr(hint, k, v)
......
"""PrimFunc Wrapper and Block information Analaysis"""
from __future__ import annotations
import tvm
......@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func):
class BlockAnalyzer:
def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch
self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
......@@ -92,7 +92,6 @@ class Edge:
class Node:
def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
self.name = name
if tags is None:
......@@ -177,7 +176,6 @@ class Node:
class PlaceHolderNode(Node):
def __init__(self, name=""):
super().__init__(name="PlaceHolder_" + name)
......@@ -189,11 +187,7 @@ class PlaceHolderNode(Node):
class PrimFuncNode(Node):
def __init__(self,
prim_func: PrimFunc,
tags: dict | None = None,
name: str = "PrimFuncNode") -> None:
def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
......@@ -227,7 +221,7 @@ class PrimFuncNode(Node):
for dst_id, n in enumerate(inputs):
if isinstance(n, Node):
n = (n, 0)
assert (len(n) == 2)
assert len(n) == 2
src_node, src_id = n[0], n[1]
edge = Edge(src_node, self, src_id, dst_id)
self._in_edges.append(edge)
......@@ -338,9 +332,8 @@ class PrimFuncNode(Node):
if rstep is None:
rstep = {}
shape = {
self.block_analyzer.get_output_buffers(block)[0].name: [
tvm.arith.ConstIntBound(0, val - 1) for val in tile
] for block in self.schedule_stages
self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile]
for block in self.schedule_stages
}
return self.ana.infer(shape, rstep, targets)
......@@ -356,10 +349,7 @@ class PrimFuncNode(Node):
results.append(shapes[arg.name])
continue
# should not exceed original shape
trimmed_shape = [
self.extent_wrapper(i)
for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
]
trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))]
results.append(trimmed_shape)
return results
......@@ -380,10 +370,8 @@ class PrimFuncNode(Node):
propagate_shape = shapes[arg.name]
buffer_shape = args[i].shape
if len(buffer_shape) > len(propagate_shape):
buffer_shape = buffer_shape[-len(propagate_shape):]
trimmed_shape = [
self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))
]
buffer_shape = buffer_shape[-len(propagate_shape) :]
trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))]
results.append(trimmed_shape)
return results
......@@ -412,10 +400,7 @@ class PrimFuncNode(Node):
def get_reduce_inputs_dtype(self):
if self.reduction_block is None:
return {}
return {
b.name: tvm.DataType(b.dtype)
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)}
@functools.lru_cache
def infer_tensorcore_axis(self) -> tuple[int]:
......@@ -425,8 +410,7 @@ class PrimFuncNode(Node):
C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok
output_buffer_shape = (
self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape
valid_region = []
for region in output_buffer_shape:
if region.value == 1:
......@@ -438,8 +422,7 @@ class PrimFuncNode(Node):
def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
spatial_dim = self.get_space_dim()
assert len(valid_region) == len(
spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
assert len(valid_region) == len(spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
cl_shapes = [1] * len(spatial_dim)
cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n
......@@ -467,9 +450,11 @@ class PrimFuncNode(Node):
shapes, _ = self.propagate(shape, rstep)
def is_broadcast_pattern(buffer, output_buffer):
return (buffer in self.args and
len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and
np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]))
return (
buffer in self.args
and len(shapes[output_buffer.name]) > len(shapes[buffer.name])
and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])
)
def is_after_reduce_stage(block):
if not self.reduction_block:
......@@ -491,8 +476,8 @@ class PrimFuncNode(Node):
output_buffer = self.block_analyzer.get_output_buffers(block)[0]
for buffer in self.block_analyzer.get_input_buffers(block):
cache = buffer.name not in cached_tensor and (
is_broadcast_pattern(buffer, output_buffer) or
self.block_analyzer.get_block_info(block).is_reduction())
is_broadcast_pattern(buffer, output_buffer) or self.block_analyzer.get_block_info(block).is_reduction()
)
if not cache:
continue
cached_tensor.append(buffer.name)
......@@ -500,8 +485,7 @@ class PrimFuncNode(Node):
continue # cache after reduce op can often reuse buffer in reduce stage
if buffer.name in stride_map:
num_elem = stride_map[buffer.name].compute_elements_from_shape(
shapes[buffer.name])
num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name])
else:
num_elem = np.prod(shapes[buffer.name])
buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)
......@@ -514,7 +498,6 @@ class PrimFuncNode(Node):
class OutputNode(Node):
def __init__(self, node, id=0):
super().__init__(name="OutputNode")
# connect node and output node
......@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]:
input_ready_count[dst_node] = len(dst_node.inputs)
list_of_nodes.append(dst_node)
input_ready_count[dst_node] -= 1
assert (input_ready_count[dst_node] >= 0)
assert input_ready_count[dst_node] >= 0
if input_ready_count[dst_node] == 0:
ready.append(dst_node)
assert (len(list_of_nodes) == len(output_list))
assert len(list_of_nodes) == len(output_list)
return output_list
def find_topo_sort_priority(output_node_list) -> list[Node]:
import sys
sys.setrecursionlimit(10000)
def topo_sort_get_layer(node, topo_layer):
......@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
if node in visited:
return
visited.add(node)
ordered_input_nodes = sorted([edge.src_node for edge in node.inputs],
key=lambda n: topo_layer[n],
reverse=True)
ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True)
for n in ordered_input_nodes:
topo_sort_dfs(n, visited, topo_order)
topo_order.append(node)
......@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
def find_topo_sort(output_node_list) -> list[Node]:
def topo_sort_dfs(node, visited, topo_order):
if node in visited:
return
......
"""Policy for cuda core schedule"""
from __future__ import annotations
import functools
import math
......@@ -36,20 +37,14 @@ class DefaultPolicy:
self.rasterization = NoRasterization()
@classmethod
def from_prim_func(cls,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: dict | None = None,
name: str = "PrimFuncNode"):
def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod
def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self,
func: tvm.tir.PrimFunc,
name: str = "PrimFuncNode") -> DefaultPolicy:
def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy:
if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
......@@ -60,9 +55,7 @@ class DefaultPolicy:
return self
def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
self.ordered_nodes = list(
filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes)))
self.ordered_nodes = list(filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes)))
for node in self.ordered_nodes:
node.update_tags(self.tags)
......@@ -102,13 +95,14 @@ class DefaultPolicy:
def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
_steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()]
steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)]
steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)]
for i in range(len(steps)):
added = list(
filter(
lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i],
[2, 4, 8, 16, 32],
))
)
)
steps[i].extend(added)
steps[i] = sorted(steps[i])
visited_tiles = {}
......@@ -190,10 +184,7 @@ class DefaultPolicy:
"""
tile_map = {}
for node in self.output_nodes:
tile_map[node] = [
tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i]
for i in range(len(tile))
]
tile_map[node] = [tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))]
return tile_map
def compute_workload_per_item(self, output_tile) -> float:
......@@ -304,8 +295,7 @@ class DefaultPolicy:
score = 0
shape = node.propagate_inputs(tile, rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers):
read_transaction_elements = self.arch.transaction_size[1] // (
(node.get_buffer_dtype(input_buffer).bits + 7) // 8)
read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8)
score += sim(
int(coalesced_factor(shape[i], input_buffer.shape)),
read_transaction_elements,
......@@ -380,17 +370,13 @@ class DefaultPolicy:
return None
return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map[node] = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
}
new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = self._compute_shared_memory_usage(td)
......@@ -434,15 +420,14 @@ class DefaultPolicy:
if edge.src_node.is_placeholder():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8
read_transaction_elements = self.arch.transaction_size[1] // nbytes
traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(),
read_transaction_elements) * nbytes
traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes
for edge in node.outputs:
if edge.dst_node.is_output():
nbytes = (edge.src_node.get_dtype().bits + 7) // 8
write_transaction_elements = self.arch.transaction_size[0] // nbytes
traffic += coalesced_tensor_shape(output_shapes[edge.src_id],
node.get_shape(edge.src_id),
write_transaction_elements) * nbytes
traffic += (
coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements) * nbytes
)
return traffic, op_tile_map
......@@ -487,10 +472,7 @@ class DefaultPolicy:
cached_tensors_map = {}
def can_free(node, out_id):
for edge in node.outputs:
if edge.src_id == out_id and edge.dst_node not in processed:
return False
return True
return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs)
for node in self.ordered_nodes:
node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node)
......@@ -528,9 +510,7 @@ class DefaultPolicy:
Tuple[Dict, Dict]
A tuple of dictionaries containing the output strides and tensor strides.
"""
output_strides = {
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
tensor_strides = {}
return output_strides, tensor_strides
......@@ -551,8 +531,7 @@ class DefaultPolicy:
output_strides_map = {}
tensor_strides_map = {}
for node in self.ordered_nodes:
output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(
node, td)
output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
......@@ -582,9 +561,7 @@ class DefaultPolicy:
output_shape = self.output_nodes[0].get_space_dim()
td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)]))
# estimated reg usage
reg_usage = int(2 * max([
np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes
]))
reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes]))
if reg_usage > self.arch.reg_cap:
td.valid = False
return td
......@@ -609,13 +586,10 @@ class DefaultPolicy:
for node in self.ordered_nodes:
if np.prod(td.get_tile(node)) == 0:
return False
node_grid_size = np.prod([
(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())
])
node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())])
if node_grid_size != td.grid_size:
return False
if (hasattr(node, "reduce_op") and node.reduce_op is not None and
len(node.reduce_op.axis) == len(td.output_tile)):
if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile):
for i, tile_extent in enumerate(td.output_tile):
if node.reduce_op.axis[i].dom.extent % tile_extent:
return False
......@@ -639,23 +613,22 @@ class DefaultPolicy:
node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes]
max_block_size = functools.reduce(math.gcd, node_space_sizes)
if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(
node_space_sizes):
node_reduce_sizes = [
int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes
]
if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes):
node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes]
total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)]
max_possible_size = functools.reduce(math.gcd, total_sizes)
possible_block_sizes = list(
filter(
lambda x: x % max_block_size == 0 and x <= 1024,
get_all_factors(max_possible_size),
))
)
)
possible_block_sizes = list(
filter(  # either be a factor of the space or fully cover the space
lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]),
possible_block_sizes,
))
)
)
factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
return factor_ordered
else:
......@@ -821,8 +794,7 @@ class DefaultPolicy:
vectorize_result = {}
for tensor, shape in shapes.items():
for v in vectorize_sizes:
if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and
is_type_allowed(dtypes[tensor], v)):
if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v):
vectorize_result[tensor] = v
break
return vectorize_result
......
"""Policy for tensorcore schedule"""
from __future__ import annotations
import tvm
import numpy as np
......@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
class TensorCorePolicy(DefaultPolicy):
# this is the trick for wmma.
# However, for int8 mma, the wmma_k should be 32.
wmma_k: int = 16
......@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
A_high_ax = min(A_ax_m, A_ax_k)
B_high_ax = min(B_ax_n, B_ax_k)
C_high_ax = min(C_ax_m, C_ax_n)
A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax)
B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax)
C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax)
A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax)
B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax)
C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax)
return A_stride, B_stride, C_stride
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
......@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
# get reduce input size
target_transaction = self.arch.transaction_size[0] * 2
# 512 bytes // type bits
reduce_input_dtype = node.get_buffer_dtype(
node.block_analyzer.get_input_buffers(node.reduction_block)[0])
reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0])
basic = (target_transaction * 8) // reduce_input_dtype.bits
result = {}
......@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
iter_name = iter_info.var.name
iter_dom = iter_info.dom.extent
if iter_dom % 16 > 0:
result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding
result[iter_name] = 16 if iter_dom < basic else basic # for the case of padding
elif iter_dom % basic == 0:
result[iter_name] = basic
else:
......@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
return False
if _check_small_tile(td):
smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
......@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
return rstep
def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map,
td.tensor_strides_map[node])
return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])
def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
......@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy):
return None
return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
for k in node.raxis
}
new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td)
......@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy):
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
return rstep
for node in self.ordered_nodes:
......@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy):
return super().get_node_reduce_step_candidates(node)
else:
# must be a multiple of wmma_k
return {
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}
return {k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis}
def check_tile_shape_isvalid(self, td: TileDict):
for node in self.ordered_nodes:
......@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy):
td.tile_map[node][ax_n],
)
# check the tile size is valid
wmma_invalid = [
block_m < wmma_m or block_n < wmma_n
for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
]
wmma_invalid = [block_m < wmma_m or block_n < wmma_n for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()]
if all(wmma_invalid):
return False
if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):
......@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy):
return super().compute_node_stride_map(node, td)
use_layout = self._can_implement_layout(node, td)
AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node),
td.get_rstep(node))
AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node))
A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
tensor_strides = {}
output_strides = {
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
tensor_strides = {}
# when connected to shared input, should use full stride without rstep
for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):
......@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy):
overall_gmem_size_in_bytes: int = 0
for node in self.ordered_nodes:
for buffer in node.input_buffers:
overall_gmem_size_in_bytes += (
int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8)
overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8
return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes
conditions.append(_check_memory_size())
......
......@@ -2,7 +2,6 @@
class Rasterization:
panel_width_ = None
def __init__(self) -> None:
......@@ -18,7 +17,6 @@ class Rasterization:
class NoRasterization(Rasterization):
def __init__(self) -> None:
super().__init__()
......
......@@ -4,9 +4,7 @@ from tvm import arith
class Statement:
def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
range_map: OrderedDict):
def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict):
self.output = output
self.dependent_region = dependent_region
self.var_map = var_map
......@@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
class InputShapeInference:
def __init__(self, deps: list[Statement]):
self.deps = deps
......
......@@ -5,7 +5,6 @@ from tvm import arith, tir
class Statement:
def __init__(self, block_analyzer, block: BlockRV):
self.block_analyzer = block_analyzer
self.block = block
......@@ -21,9 +20,7 @@ class Statement:
if len(self.dependent_region[input_name]) != 1:
return None
indices = self.dependent_region[input_name][0]
iter_map_range = {
_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)
}
iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)}
iter_map_result = arith.detect_iter_map(
indices,
iter_map_range,
......@@ -77,7 +74,6 @@ class TensorDepNode:
class DependencyAnalysis:
def __init__(self, deps):
self.deps = deps
# issue: duplicate name when we have two same ops.
......@@ -112,8 +108,7 @@ class DependencyAnalysis:
def traverse_dependencies(self, compute):
if isinstance(compute, Statement):
node = self.get_or_create_node(
compute.block_analyzer.get_output_buffers(compute.block)[0].name)
node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name)
# Loop through input tensors
for input_buffer in compute.block_analyzer.get_input_buffers(compute.block):
# Get the input node
......@@ -167,7 +162,6 @@ class DependencyAnalysis:
class InputShapeInference:
def __init__(self, deps: list[Statement]):
self.deps = deps
self.target_mapping = {}
......@@ -183,16 +177,11 @@ class InputShapeInference:
if targets in self.target_mapping:
return self.target_mapping[targets]
# should be buffer name instead of block name
name2dep = {
dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps
}
name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps}
mapping = {}
input_vars = []
for target in targets:
vars = [
iter.var
for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)
]
vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)]
input_vars.append(vars)
mapping[target] = [vars]
ana = arith.Analyzer()
......@@ -221,13 +210,8 @@ class InputShapeInference:
mapping[input_name] = []
for indices in indices_list:
for region in regions:
vmap = {
k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v)
for k, v in zip(ax_vars, indices)
}
region = [
ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region
]
vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)}
region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region]
if not region_exist_in_list(region, mapping[input_name]):
mapping[input_name].append(region)
buffers = []
......@@ -241,10 +225,7 @@ class InputShapeInference:
self.target_mapping[targets] = input_vars, mapping
return input_vars, mapping
def infer(self,
shape: dict[str, list[arith.ConstIntBound]],
rstep: dict[str, int] = None,
targets=None):
def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None):
if rstep is None:
rstep = {}
compute_targets = tuple(shape.keys())
......@@ -258,8 +239,7 @@ class InputShapeInference:
for ax in self.reduce_axes:
# assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value.
if ax.var.name in rstep:
bound = arith.ConstIntBound(
int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
else:
bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1))
ana.update(ax.var, bound, True)
......@@ -312,14 +292,11 @@ class InputShapeInference:
for name, regions in mapping.items():
region = regions[0]
result[name] = [
ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region
]
result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region]
return result
def region_exist_in_list(a, list) -> bool:
def expr_is_same(a, b) -> bool:
if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm):
return a.value == b.value
......
......@@ -2,7 +2,12 @@
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
TileDevice,
is_volta_arch,
is_ampere_arch,
is_cdna_arch,
auto_infer_current_arch,
)
from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
......@@ -41,7 +46,7 @@ class BaseTemplate(ABC):
"""
pass
def with_arch(self, arch: TileDevice) -> 'BaseTemplate':
def with_arch(self, arch: TileDevice) -> "BaseTemplate":
"""
Sets the architecture for this template and returns itself.
......@@ -109,7 +114,7 @@ class BaseTemplate(ABC):
"""
raise NotImplementedError("initialize_function is not implemented")
def set_function(self, func: PrimFunc) -> 'BaseTemplate':
def set_function(self, func: PrimFunc) -> "BaseTemplate":
"""
Sets the function for this template and returns itself.
......@@ -122,7 +127,7 @@ class BaseTemplate(ABC):
self._func = func
return self
def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate':
def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate":
"""
Sets the output nodes for this template and returns itself.
......
......@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype (str): Data type used for accumulation.
with_bias (bool): Whether to add a bias term.
"""
# Operation-related configuration parameters
N: int # The number of input samples processed simultaneously in a batch.
C: int # The number of input feature maps.
......@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
"""
N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P
assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and
isinstance(W, int) and isinstance(F, int) and isinstance(K, int) and
isinstance(S, int) and isinstance(D, int) and
isinstance(P, int)), "Only Support Integer Params"
assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and
P > 0), "Params should be positive"
assert (
isinstance(N, int)
and isinstance(C, int)
and isinstance(H, int)
and isinstance(W, int)
and isinstance(F, int)
and isinstance(K, int)
and isinstance(S, int)
and isinstance(D, int)
and isinstance(P, int)
), "Only Support Integer Params"
assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive"
# Load configuration parameters
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
......@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
te.if_then_else(
te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype),
tir.const(0, accum_dtype)),
axis=[kh, kw, c])
tir.const(0, accum_dtype),
),
axis=[kh, kw, c],
)
# Compute convolution result
C = te.compute(
......
......@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_
@dataclass
class FlashAttentionTemplate(BaseTemplate):
_output_nodes: list[OutputNode] = None
# Operation-related configuration parameters
......@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate):
"""
A_indices = [b, i, k]
B_indices = [b, j, k]
return te.sum(
A[tuple(A_indices)].astype(accum_dtype) *
B[tuple(B_indices)].astype(accum_dtype),
axis=k)
return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
# Compute matrix multiplication result
C = te.compute(
......
......@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate):
N, K = self.N, self.K
# Ensure M, N, K are valid positive integers
assert (isinstance(M, int) and isinstance(N, int) and
isinstance(K, int)), "Only Support Integer M, N, K"
assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
# Load configuration parameters
trans_B = self.trans_B
......@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate):
"""
A_indices = [i, k]
B_indices = [k, j] if not trans_B else [j, k]
return te.sum(
A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype),
axis=k)
return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
# Compute matrix multiplication result
C = te.compute(
......
......@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func
@dataclass
class GeneralReductionTemplate(BaseTemplate):
# OP Related Config
structure: str | list[str] = None
shape: list[int] = None
dtype: str = "float16"
def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
roller_hints = get_roller_hints_from_func(
self._func, arch=arch, topk=topk, allow_gemv=False)
roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False)
return roller_hints
def initialize_function(self) -> None:
......@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate):
spatial_axes = []
reduce_axes = []
for i, axis_type in enumerate(self.structure):
if axis_type.upper() == 'S':
if axis_type.upper() == "S":
spatial_axes.append((i, self.shape[i]))
elif axis_type.upper() == 'R':
elif axis_type.upper() == "R":
reduce_axes.append((i, self.shape[i]))
else:
raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")
......@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate):
# Walk through the structure in order
for axis_type in self.structure:
if axis_type.upper() == 'S':
if axis_type.upper() == "S":
# use the next spatial_indices item
full_index.append(spatial_indices[spatial_iter])
spatial_iter += 1
......
......@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate):
M, N, K = self.M, self.N, self.K
# Ensure M, N, K are valid positive integers
assert (isinstance(M, int) and isinstance(N, int) and
isinstance(K, int)), "Only Support Integer M, N, K"
assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
# Load configuration parameters
trans_A, trans_B = self.trans_A, self.trans_B
......@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate):
"""
A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed
B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed
return te.sum(
A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype),
axis=k)
return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)
# Compute matrix multiplication result
C = te.compute(
......
......@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
"""
def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
arch: TileDevice,
topk: int = 10,
tensorcore_only: bool = False,
allow_gemv: bool = False) -> list[Hint] | None:
def get_roller_hints_from_func(
func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False
) -> list[Hint] | None:
func = None
if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module
......@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
roller_hints = None
if tensorcore_only:
try:
tensorized_func, tags = get_tensorized_func_and_tags(
func, arch.target, allow_gemv=allow_gemv)
tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
except Exception as e_msg:
logger.debug("Get tensorized func and tags failed: ", e_msg)
tags = None
......@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None
try:
tensorized_func, tags = get_tensorized_func_and_tags(
func, arch.target, allow_gemv=allow_gemv)
tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
except Exception as e_msg:
logger.debug("Get tensorized func and tags failed: ", e_msg)
tags = None
......@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
return roller_hints
def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
arch: TileDevice,
topk: int = 10,
extra_tags: list[str] | None = None) -> list[Hint] | None:
def get_roller_hints_from_output_nodes(
output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None
) -> list[Hint] | None:
assert isinstance(output_nodes, list), "The input should be a list of functions."
lints = []
......@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
lints = policy.emit_config(topk)
except Exception as e_msg:
logger.debug(f"Generate hints from output nodes failed: {e_msg}",
"fallback to default policy")
logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy")
if len(lints) == 0:
policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
......@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
if not isinstance(ir_module, IRModule):
raise ValueError("Not supported type: ", type(ir_module))
assert len(ir_module.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule."
func = list(ir_module.functions.values())[0]
return func
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
import functools
import os
import shutil
......@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils
def _is_linux_like():
return (sys.platform == "darwin" or sys.platform.startswith("linux") or
sys.platform.startswith("freebsd"))
return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd")
def _is_windows_like():
......@@ -90,7 +90,7 @@ def get_cplus_compiler():
def is_darwin():
return platform.system() == 'Darwin'
return platform.system() == "Darwin"
def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
......@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None,
add_files=None):
def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None):
"""Create a cross compiler function by specializing compile_func with options.
This function can be used to construct compile functions that
......@@ -363,13 +359,7 @@ def cross_compiler(compile_func,
return _fcompile
def _linux_compile(output,
objects,
options,
compile_cmd,
cwd=None,
ccache_env=None,
compile_shared=False):
def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False):
cmd = [compile_cmd]
if compile_cmd != "nvcc":
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
......@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
raise ValueError("ccache not found")
try:
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
(out, _) = proc.communicate()
except FileNotFoundError:
raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)."
raise RuntimeError(
"Can not find the LLVM clang for Windows clang.exe)."
"Make sure it's installed"
" and the installation directory is in the %PATH% environment "
"variable. Prebuilt binaries can be found at: https://llvm.org/") \
from None
"variable. Prebuilt binaries can be found at: https://llvm.org/"
) from None
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += " ".join(cmd) + "\n"
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm import runtime
......@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
def adapt_tensor(arg):
if isinstance(arg, tensor_type):
if arg.dtype in {
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz
}:
return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype])
if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype])
return runtime.from_dlpack(to_dlpack_func(arg))
return arg
......