Commit cd191889 authored by Lei Wang, committed by GitHub

[Carver] Introduce a tile-structure based cost model for auto tuning (#70)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University

* enhance acknowledgement

* Replace tutorial on memory layout optimization with new tutorial on writing high-performance kernels with thread primitives

* Update subproject commit for TVM dependency

* Update subproject commit for TVM dependency

* Add int4_t type and functions for packing char values in CUDA common header

* Add plot_layout example and implement GetForwardVars method in layout classes

* Refactor code for improved readability by adjusting line breaks and formatting in layout and test files

* Fix formatting by removing unnecessary line break in layout.h

* Refactor make_int4 function for improved readability by adjusting parameter formatting

* Add legend to plot_layout for improved clarity of thread and local IDs

* Remove unnecessary dependencies from requirements files for cleaner setup

* Remove flash_mha.py and add .gitkeep to deepseek_mla directory

* Add build requirements and update installation scripts for improved setup

* Introduce carver

* Refactor imports and improve code formatting for consistency

* Add unit tests for carver recommendation hints

* lint fix

* Enhance ElementwiseTemplate and BaseTemplate with detailed docstrings for improved code documentation and clarity

* Refactor import statements and clean up whitespace in template files for improved readability

* Add README.md for Carver framework with usage examples and architecture support
parent 2411fa28
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.testing
from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch
from typing import List
def run_general_reduction_recommend_hints(structure: str = "SSR",
shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate(
structure=structure,
shape=shape,
dtype=dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=topk)
assert len(hints) > 0, "Hints length is zero"
def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], "float16")
run_general_reduction_recommend_hints("SS", [1024, 1024], "float16")
run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16")
def run_elementwise_recommend_hints(shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate(
shape=shape,
dtype=dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=topk)
assert len(hints) > 0, "Hints length is not topk"
def test_elementwise_recommend_hints():
run_elementwise_recommend_hints([1024, 1024], "float16")
run_elementwise_recommend_hints([1024], "float16")
run_elementwise_recommend_hints([1024, 1024, 1024], "float16")
def run_matmul_recommend_hints(
M: int = 1024,
N: int = 1024,
K: int = 1024,
in_dtype: str = "float16",
out_dtype: str = "float16",
accum_dtype: str = "float16",
):
arch = auto_infer_current_arch()
carve_template = carver.MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=20)
assert len(hints) > 0, "Hints length is not 20"
def test_matmul_recommend_hints():
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float16", "float16")
run_matmul_recommend_hints(1024, 1024, 1024, "int8", "int32", "int32")
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16")
def run_gemv_recommend_hints(N: int = 1024,
K: int = 1024,
in_dtype: str = "float16",
out_dtype: str = "float16",
accum_dtype: str = "float16"):
arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate(
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=20)
assert len(hints) > 0, "Hints length is not 20"
def test_gemv_recommend_hints():
run_gemv_recommend_hints(1024, 1024, "float16", "float16", "float16")
run_gemv_recommend_hints(1024, 1024, "int8", "int32", "int32")
run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16")
if __name__ == "__main__":
tilelang.testing.main()
# Carver: A Tile-Structure Based Hint Recommendation Framework for Machine Learning Compilers
**Carver** is a lightweight framework for generating and ranking tile configurations (also known as **tiling strategies**, **blocking schemes**, or **scheduling hints**) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.
Carver combines hardware architecture information, user-defined tile structures, and built-in heuristics to recommend tiling strategies (or "hints"). The recommended hints are easily adapted to multiple backends, including [TVM](https://tvm.apache.org/), [Triton](https://github.com/openai/triton), [tilelang](https://github.com/LeiYanggh/tilelang), and other domain-specific compilers.
---
### Key Features
- **Unified Tiling Framework**: Generate tile candidates for multiple backends under a unified API.
- **Architecture-Specific Modeling**: Take into account architecture constraints (e.g., CUDA `smem_cap`, warp size, CPU cache structure, etc.) when generating hints.
- **Flexible Templates**: High-level templates (like `MatmulTemplate`, `GeneralReductionTemplate`, `ElementwiseTemplate`) let you concisely specify kernel structures.
- **Extendable**: Easily add support for new backends and new operation templates.
---
## Usage Examples
### Basic Usage: General Reduction Template
Once tilelang is installed, you can import Carver and start creating templates:
```python
from tilelang import carver
from tilelang.carver.arch import CUDA
# Instantiate a CUDA device object for an RTX 4090
arch = CUDA("nvidia/geforce-rtx-4090")
# Create a general reduction template for a loop nest:
# for i in Spatial(1024):
# for j in Spatial(1024):
# for k in Reduce(1024):
# ...
carve_template = carver.GeneralReductionTemplate(
structure="SSR",
shape=[1024, 1024, 1024],
dtype="float16",
).with_arch(arch)
# Generate top 20 tile candidates (aka scheduling hints)
hints = carve_template.recommend_hints(topk=20)
for hint in hints:
print(hint)
```
**Example Output** (truncated):
```python
{
'block': [1, 128],
'thread': [1, 128],
'rstep': [64],
...
},
{
'block': [2, 64],
'thread': [2, 64],
'rstep': [64],
...
},
...
{
'block': [1, 16],
'thread': [1, 16],
'rstep': [512],
'reduce_thread': [8],
...
}
```
A tile structure composed of S (spatial) and R (reduce) axes can describe a wide range of workloads. For example, the structure `SS` represents a 2D elementwise operation, while `SSR` can represent a general matrix multiplication.
More specialized templates, such as `MatmulTemplate`, expose finer-grained information.
### Matmul Template
Carver also provides a specialized `MatmulTemplate` for matrix multiplication (e.g., `C = A * B`), automatically inferring common tiling strategies (thread blocks, warps, use of tensor cores, etc.).
```python
from tilelang import carver
from tilelang.carver.arch import CUDA
arch = CUDA("nvidia/geforce-rtx-4090")
carve_template = carver.MatmulTemplate(
M=1024,
N=1024,
K=1024,
in_dtype="float16",
accum_dtype="float16",
out_dtype="float16",
).with_arch(arch)
# Retrieve the (symbolic) function describing the matmul
func = carve_template.equivalent_function()
print("Equivalent Function:\n", func)
# Generate hints
hints = carve_template.recommend_hints(topk=20)
for hint in hints:
print(hint)
```
**Example Output**:
```python
{
'block': [32, 64],
'warp': [16, 32],
'rstep': [128],
'use_tc': True,
...
},
{
'block': [64, 32],
'warp': [32, 16],
'rstep': [128],
'use_tc': True,
...
},
...
{
'block': [256, 32],
'warp': [128, 16],
'rstep': [32],
'use_tc': True,
...
}
```
---
## Supported Architectures
Carver currently provides out-of-the-box support for:
- **CUDA**: e.g., `arch = CUDA("nvidia/geforce-rtx-4090")`
- **CDNA** (AMD GPU-like backends)
- **CPU**
Adding a new architecture is as simple as implementing a new subclass of `TileDevice` or providing a custom target that describes:
- Shared/local memory capacity
- Warp (or vector) size
- Cache sizes
- Tensor instructions available
Below is an **illustrative snippet** of the CUDA backend:
```python
class CUDA(TileDevice):
def __init__(self, target: Union[tvm.target.Target, str]):
...
self.platform = "CUDA"
# Device constraints
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
...
self.transaction_size = [32, 128] # bytes
self.bandwidth = [750, 12080] # MB/s, approximate
self.available_tensor_instructions = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
TensorInstruction("mma", [16, 16]),
TensorInstruction("wmma", [16, 16]),
)
return [t.shape for t in self.available_tensor_instructions]
def __repr__(self):
return f"CUDA({self.target})"
```
## Adapting Hints to Other Compilers
One of Carver’s main benefits is its adaptability. Here is an example for Triton.
Given a Carver hint like:
```python
{
'block': [32, 64],
'warp': [16, 32],
'rstep': [128],
'use_tc': True,
'vectorize': {'A_reindex': 8, 'B_reindex': 8}
}
```
You might interpret this in **Triton** as:
- `block_m = 32, block_n = 64, block_k = 128`
- Potential warp usage = `warp_m = 16, warp_n = 32`
- `vectorize`: load data with a vector width of 8
- If `use_tc` is true, consider using Tensor Cores (in Triton, typically via `tl.dot`) if the hardware supports them.
This helps you quickly test multiple configurations without guessing manually.
## Supported Templates
Carver abstracts common loop patterns through templates:
- **`GeneralReductionTemplate`**: For general `Spatial-Spatial-Reduce` (SSR) structures or similar.
- **`MatmulTemplate`**: For standard matrix multiplication `C = A * B`.
- **`GEMVTemplate`**: For `y = Ax` or `y = xA` style operations.
- **`ElementwiseTemplate`**: For elementwise transformations or pointwise ops.
You can also create your own specialized templates if you have unique loop structures or constraints. For instance, you might define specialized templates for convolution, flash attention, etc.
## TODO Items
- [ ] **Flash Attention** and its variants: Support search-space generation for specialized attention kernels.
- [ ] **Adapt to tile language**: Provide ready-made scheduling calls or wrappers for [tilelang](https://github.com/LeiYanggh/tilelang) to streamline end-to-end integration.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Base infra"""
from .analysis import (
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
collect_block_iter_vars_used_in_access_region, # noqa: F401
collect_vars_used_in_prim_expr, # noqa: F401
detect_dominant_read, # noqa: F401
is_broadcast_epilogue, # noqa: F401
normalize_prim_func, # noqa: F401
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Set, Union
from typing_extensions import Literal
from tvm import ir, tir, DataType
from tvm._ffi import get_global_func
from tvm.target.target import Target
from tvm.tir import Schedule, IterVar
from tvm.tir.schedule import BlockRV
class IterInfo:
"""Information about a loop/iter var."""
kind: Literal["S", "R", "O"]
var: tir.Var
_dom: tir.PrimExpr
loop_rv: tir.schedule.LoopRV
def __init__(
self,
kind: Literal["S", "R", "O"],
var: tir.Var,
dom: tir.PrimExpr,
loop_rv: tir.schedule.LoopRV,
):
"""Construct an IterInfo object."""
self.kind = kind
self.var = var
self._dom = dom
self.loop_rv = loop_rv
@property
def dom(self) -> Union[int, tir.PrimExpr]:
"""The iteration domain of the loop."""
return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
def __str__(self) -> str:
return f'Iter("{self.kind}", {self.dom})'
def __repr__(self) -> str:
return str(self)
class BlockInfo:
"""Information about a TIR block."""
name: str
iters: List[IterInfo]
block_rv: tir.schedule.BlockRV
_reduction_block: bool
def __init__(
self,
name: str,
iters: List[IterInfo],
block_rv: tir.schedule.BlockRV,
reduction_block: bool = False,
):
"""Construct a BlockInfo object."""
self.name = name
self.block_rv = block_rv
self.iters = iters
self._reduction_block = reduction_block
def dom(self) -> List[Union[int, tir.PrimExpr]]:
"""The iteration domain of the block."""
return [i.dom for i in self.iters]
def dom_kind(self) -> str:
"""The iteration domain kind of the block, for example, SSSS, SSSR."""
return "".join(i.kind for i in self.iters)
def is_injective(self) -> bool:
"""Whether the block is injective, i.e. all its iteration domains are injective."""
return all(k == "S" for k in self.dom_kind())
def is_elementwise(self, sch: tir.Schedule) -> bool:
"""Whether the block is elementwise, i.e. trivial mapping between read/write region"""
def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
return dom.min.same_as(var) and dom.extent == 1
if not self.is_injective():
return False
block = sch.get(self.block_rv)
if len(block.reads) != 1 or len(block.writes) != 1:
return False
r_region = block.reads[0].region
w_region = block.writes[0].region
if len(r_region) != len(w_region):
return False
for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region):
if not _check_unit_var_range(r_dom, var.var) or not _check_unit_var_range(w_dom, var.var):
return False
return True
def is_reduction(self) -> bool:
"""Whether the block is a reduction workload."""
# TODO(@junrushao): distinguish GEMV and reduction
return self._reduction_block
def is_gemv(self) -> bool:
"""Whether the block is a GEMV workload."""
raise NotImplementedError
def is_gemm(self) -> bool:
"""Whether the block is a GEMM workload."""
raise NotImplementedError
def __str__(self) -> str:
return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})'
def __repr__(self) -> str:
return str(self)
_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
"""Normalize the primfunc to normal form"""
try:
result = _normalize_prim_func(sch)
if result is None:
return None
except Exception: # pylint: disable=broad-except
return None
def _iter_kind(i: tir.IterVar) -> str:
return {
tir.IterVar.DataPar: "S",
tir.IterVar.CommReduce: "R",
}.get(i.iter_type, "O")
blocks: List[BlockInfo] = []
for block, loops, iters, is_reduction in zip(*result):
blocks.append(
BlockInfo(
name=sch.get(block).name_hint,
iters=[
IterInfo(
kind=_iter_kind(iter), # type: ignore
var=iter.var,
dom=iter.dom,
loop_rv=loop,
) for loop, iter in zip(loops, iters)
],
block_rv=block,
reduction_block=is_reduction,
))
return blocks
def find_var_from_func(func, var: str):
for buffer in func.buffer_map.values():
for i in buffer.shape:
if isinstance(i, tir.Var) and i.name == var:
return i
return None
def check_func_with_dynamic(func):
for buffer in func.buffer_map.values():
for i in buffer.shape:
if isinstance(i, tir.Var):
return True
return False
def _assert_gpu_target(target: Target):
if "gpu" not in target.keys:
raise ValueError(f"Expect a GPU target, but got {target}")
def get_max_threads_per_block(target: Target) -> int:
_assert_gpu_target(target)
max_threads_per_block = None
for name in ["max_threads_per_block", "max_num_threads"]:
if max_threads_per_block is None:
max_threads_per_block = target.attrs.get(name, None)
if max_threads_per_block is None:
max_threads_per_block = 64
return int(max_threads_per_block)
def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target(target)
max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
if max_shared_memory_per_block is None:
raise ValueError(
f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
return int(max_shared_memory_per_block)
def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try:
block = sch.mod[func_name].body.block
except Exception:
raise ValueError(f"The function body is expected to be the root block, but got:\n"
f"{sch.mod[func_name].body}") from None
return sch.get_block(block.name_hint)
def collect_block_iter_vars_used_in_access_region(block: tir.Block,
region: List[ir.Range]) -> Set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set()
for expr in region:
if expr.extent != 1:
continue
tir_vars |= collect_vars_used_in_prim_expr(expr.min)
tir_vars &= set(iter_var.var for iter_var in block.iter_vars)
return tir_vars
def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
"""Collect the variables used in the PrimExpr."""
tir_vars = set()
def _collect_tir_var(expr):
if isinstance(expr, tir.Var):
tir_vars.add(expr)
tir.stmt_functor.post_order_visit(expr, _collect_tir_var)
return tir_vars
def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
"""Detect the dominant read indices in the block."""
dominant_read = None
num_read_iters = -1
for buffer_region in block.reads:
tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region)
if num_read_iters < len(tir_vars):
num_read_iters = len(tir_vars)
dominant_read = buffer_region
assert dominant_read is not None
(result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region])
return result
def is_broadcast_epilogue(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
epilogue: tir.schedule.BlockRV,
) -> bool:
"""Check if the epilogue block is a broadcast pattern"""
write_buffers = {r.buffer for r in sch.get(block).writes}
epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1}
for buffer_region in sch.get(epilogue).reads:
if buffer_region.buffer not in write_buffers:
continue
tir_vars = collect_block_iter_vars_used_in_access_region(
sch.get(epilogue), buffer_region.region)
if len(tir_vars) < len(epilogue_iters):
return True
return False
def get_reduction_blocks(sch: tir.Schedule,
blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block)
iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
return iter_types == {IterVar.CommReduce, IterVar.DataPar}
def is_spatial(block: BlockRV) -> bool:
block_stmt = sch.get(block)
iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
return iter_types == {IterVar.DataPar}
# NOTE: We assume there is only one reduction block in the function
# all blocks are required to be spatial or reduction
if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
return None
# There is only one reduction block
reduction_blocks = [block for block in blocks if is_reduction(block)]
if len(reduction_blocks) == 0:
return None
return reduction_blocks
def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
# GPU memory prefers 128-bit coalesced accesses (e.g., four 32-bit banks)
buffers: List[tir.Buffer] = []
for read in block_stmt.reads:
buffers.append(read.buffer)
for write in block_stmt.writes:
buffers.append(write.buffer)
# pick the dtype with the largest bits
max_dtype_bits: int = 0
for buffer in buffers:
max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits)
return target_bits // max_dtype_bits
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .arch_base import TileDevice
from .cuda import CUDA
from .cpu import CPU
from .cdna import CDNA
from typing import Union
from tvm.target import Target
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
if isinstance(target, str):
target = Target(target)
if target.kind.name == "cuda":
return CUDA(target)
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
return CDNA(target)
else:
raise ValueError(f"Unsupported target: {target.kind.name}")
def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
return get_arch("cuda")
from .cpu import is_cpu_arch # noqa: F401
from .cuda import (
is_cuda_arch, # noqa: F401
is_volta_arch, # noqa: F401
is_ampere_arch, # noqa: F401
is_ada_arch, # noqa: F401
is_hopper_arch, # noqa: F401
is_tensorcore_supported_precision, # noqa: F401
has_mma_support, # noqa: F401
)
from .cdna import is_cdna_arch # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
class TileDevice:
"""
Represents the architecture of a computing device, capturing various hardware specifications.
"""
def __init__(self) -> None:
self.reg_cap: int = 0 # Register capacity: The amount of register memory available
self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available
self.compute_max_core: int = 0 # The maximum number of computing cores
self.warp_size: int = (
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
self.transaction_size: List[int] = [
0,
0,
] # The size of memory transactions, typically in bytes
self.max_smem_usage: int = 0 # The maximum shared memory usage allowed
self.bandwidth: List[int] = [
0,
0,
] # Bandwidth specifications, possibly including peak and sustained rates
self.platform: str = "unknown" # The platform or manufacturer of the device
self.compute_capability: str = (
"unknown" # The compute capability, indicating the feature set and performance level
)
self.l2_cache_size_bytes: int = 0
# the number of transaction size in bytes
self.transaction_size: List[int] = [0, 0] # in bytes
# bandwidth in MB/s, will be used for recommend basic tile size
self.bandwidth: List[int] = [0, 0]
def get_avaliable_tensorintrin_shapes(self):
raise NotImplementedError()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
def is_cdna_arch(arch: TileDevice) -> bool:
return isinstance(arch, CDNA)
class CDNA(TileDevice):
def __init__(self, target: Union[Target, str]):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
device = tvm.runtime.rocm(0)
if not device.exist:
raise RuntimeError("Cannot find HIP device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CDNA"
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
self.compute_capability = device.compute_version.replace(".", "")
self.reg_cap: int = 32768
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
self.transaction_size: List[int] = [32, 128] # in bytes
self.bandwidth: List[int] = [1300, 14000]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
def is_cpu_arch(arch: TileDevice) -> bool:
return isinstance(arch, CPU)
# For the LLVM backend we do not provide detailed CPU information,
# as the LLVM backend does not require tuning; this class mainly keeps the interface consistent.
class CPU(TileDevice):
def __init__(self, target: Target):
self.target = target
device = tvm.runtime.cpu(0)
if not device.exist:
raise RuntimeError("Cannot find cpu device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CPU"
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
def is_cuda_arch(arch: TileDevice) -> bool:
return isinstance(arch, CUDA)
def is_volta_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 70)
conditions.append(arch.sm_version < 80)
return all(conditions)
def is_ampere_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80 and arch.sm_version < 89)
return all(conditions)
def is_ada_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 89)
return all(conditions)
def is_hopper_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 90)
return all(conditions)
def has_mma_support(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80)
return all(conditions)
volta_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
]
ampere_tensorcore_supported = [
("bfloat16", "float32"),
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("int4", "int32"),
("int2", "int32"),
("int1", "int32"),
]
ada_tensorcore_supported = [
("bfloat16", "float32"),
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("e5m2_float8", "float32"),
("e4m3_float8", "float32"),
]
hopper_tensorcore_supported = ada_tensorcore_supported
# TODO(lei): we should consider the dtypes of inputs a and b separately
# instead of assuming both a and b share the same dtype,
# as the tensor core may support e4m3_float8 * e5m2_float8.
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_volta_arch(arch):
return (in_dtype, accum_dtype) in volta_tensorcore_supported
elif is_ampere_arch(arch):
return (in_dtype, accum_dtype) in ampere_tensorcore_supported
elif is_ada_arch(arch):
return (in_dtype, accum_dtype) in ada_tensorcore_supported
elif is_hopper_arch(arch):
return (in_dtype, accum_dtype) in hopper_tensorcore_supported
else:
raise ValueError(f"Unsupported architecture: {arch}")
class TensorInstruction(object):
def __init__(
self,
name: str,
shape: List[int],
):
self.name: str = name
# only hold the shape of M and N
self.shape: List[int] = shape
class CUDA(TileDevice):
def __init__(self, target: Union[Target, str]):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
self.sm_version = check_sm_version(self.target.arch)
device = tvm.runtime.cuda(0)
if not device.exist:
raise RuntimeError("Cannot find cuda device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CUDA"
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
self.compute_capability = device.compute_version.replace(".", "")
self.reg_cap: int = 65536
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
# memory transaction sizes in bytes
self.transaction_size: List[int] = [32, 128] # in bytes
# bandwidth in MB/s, used to recommend a basic tile size
# TODO(lei): find a way to query the real bandwidth.
# However, the bandwidth ratio between memory levels is similar across
# devices, so these values work reasonably well for other devices too.
self.bandwidth: List[int] = [750, 12080]
# query the available tensor instructions at runtime to avoid
# a hard dependency on tensor intrinsic registration
self.available_tensor_instructions: List[TensorInstruction] = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
TensorInstruction("mma", [16, 16]),
TensorInstruction("wmma", [16, 16]),
)
return [t.shape for t in self.available_tensor_instructions]
def __repr__(self):
return f"CUDA({self.target})"
# Copyright 2018 The apache/tvm Authors. All Rights Reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
from typing import Callable, List
from tvm import tir
from .utils import retrieve_func_from_module
from .analysis import BlockInfo
def get_block(
sch: tir.Schedule,
blocks: List[BlockInfo],
name: str,
):
"""Get the target block from a schedule.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to get target block.
name : str
The name of the target block.
Returns
-------
target_block : BlockRV
The target block.
"""
target_block: tir.BlockRV = None
for block_info in blocks:
block = block_info.block_rv
if sch.get(block).name_hint == name:
target_block = block
return target_block
def get_output_blocks(
sch: tir.Schedule,
blocks: List[BlockInfo],
):
"""Get the output blocks of a schedule.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to get output blocks.
blocks : List[BlockInfo]
The blocks to be analyzed.
Returns
-------
output_blocks : List[BlockInfo]
The output blocks.
"""
# collect arguments buffer
func = retrieve_func_from_module(sch.mod)
args = list(func.buffer_map.values())
output_blocks = []
for block_info in blocks:
block = block_info.block_rv
for write in sch.get(block).writes:
if write.buffer in args:
output_blocks.append(block)
return output_blocks
def try_inline(
sch: tir.Schedule,
blocks: List[BlockInfo],
) -> List[BlockInfo]:
"""Try to inline as many blocks as possible, and return the remaining blocks.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to inline blocks.
blocks : List[BlockInfo]
The blocks to be inlined.
Returns
-------
remaining : List[BlockInfo]
The remaining blocks that cannot be inlined.
"""
def _trial(func: Callable):
for i, block in enumerate(blocks):
try:
func(block.block_rv)
except Exception: # pylint: disable=bare-except
continue
return i
return None
while True:
i = _trial(sch.compute_inline)
if i is None:
i = _trial(sch.reverse_compute_inline)
if i is None:
break
blocks.pop(i)
return blocks
def try_inline_contiguous_spatial(
sch: tir.Schedule,
block_infos: List[BlockInfo],
) -> List[BlockInfo]:
"""Try to inline contiguous spatial blocks in a schedule
Parameters
----------
sch : tir.Schedule
The TIR schedule used to inline blocks.
block_infos : List[BlockInfo]
The blocks to try to inline.
Returns
-------
remaining : List[BlockInfo]
The remaining blocks that cannot be inlined.
"""
if block_infos is None:
return None
results = []
spatial_blocks = []
block: BlockInfo
for block in block_infos:
if block.is_injective():
spatial_blocks.append(block)
elif spatial_blocks:
results.extend(try_inline(sch, spatial_blocks))
results.append(block)
spatial_blocks = []
else:
results.append(block)
if spatial_blocks:
results.extend(try_inline(sch, spatial_blocks))
return results
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from .analysis import (
collect_block_iter_vars_used_in_access_region,
get_root_block,
get_reduction_blocks,
)
from tvm.target.target import Target
from tvm.tir.stmt_functor import pre_order_visit
from .arch import get_arch, is_tensorcore_supported_precision
import logging
logger = logging.getLogger(__name__)
def collect_vars_from_expr(prim_expr):
vars = []
def callback(node):
if isinstance(node, Var):
vars.append(node)
return True
pre_order_visit(prim_expr, callback)
return vars
def _is_one(x: PrimExpr) -> bool:
return isinstance(x, tir.IntImm) and x.value == 1
def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
result = []
for producer in sch.get_producers(block):
result.append(producer)
result.extend(_collect_producers(sch, producer))
return result
def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
result = []
for consumer in sch.get_consumers(block):
result.append(consumer)
result.extend(_collect_consumers(sch, consumer))
return result
def auto_inline_producers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
skip_blocks: Optional[List[tir.schedule.BlockRV]] = None,
):
skip_blocks = skip_blocks or []
while True:
inlined_cnt = 0
producers = _collect_producers(sch, block)
for producer in producers:
if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks):
continue
try:
sch.compute_inline(producer)
inlined_cnt += 1
except Exception: # pylint: disable=bare-except
continue
if inlined_cnt == 0:
return
def auto_inline_consumers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
):
while True:
inlined_cnt = 0
consumers = _collect_consumers(sch, block)
for consumer in consumers:
try:
sch.compute_inline(consumer)
inlined_cnt += 1
except Exception: # pylint: disable=bare-except
continue
for consumer in consumers:
try:
sch.reverse_compute_inline(consumer)
inlined_cnt += 1
except Exception: # pylint: disable=bare-except
continue
if inlined_cnt == 0:
return
def auto_inline_consumer_chain(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
):
auto_inline_consumers(sch, block)
remaining_consumers = sch.get_consumers(block)
if len(remaining_consumers) != 0:
# Some blocks have failed to be inlined to the producer cache-write stage.
# This could be due to another producer block that has not been scheduled.
for c in remaining_consumers:
for p in sch.get_producers(c):
if sch.get(p) != sch.get(block):
sch.compute_inline(p)
# Try inlining into the cache-write stage again, this time it should succeed.
auto_inline_consumers(sch, block)
# used to match the similar region with dequantize op.
def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region
return None
# used to match the similar buffer with dequantize op.
def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region.buffer
return None
# find the block that required to be reindex and scope.
def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]:
# the block nearest to the arguments
block = main_block
buffer = buffer
while True:
last_buffer = buffer
producers = sch.get_producers(block)
if len(producers) == 0:
# do not have any producer means it is the first block
break
for producer in producers:
for write in sch.get(producer).writes:
if write.buffer == buffer:
block = producer
buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer)
if buffer == last_buffer:
break
return block
def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV,
buffer: tir.Buffer) -> int:
"""traverse to find the arg index from the buffer"""
producers = sch.get_producers(main_block)
# a head buffer has no producer blocks
def find_args_index(sch: tir.Schedule, buffer: tir.Buffer):
for i, param in enumerate(sch.mod["main"].params):
if sch.mod["main"].buffer_map[param] == buffer:
return i
return None
is_head_buffer = len(producers) == 0
if is_head_buffer:
return find_args_index(sch, buffer)
for block in sch.get_producers(main_block):
if len(sch.get(block).reads) != 1 or len(sch.get(block).writes) != 1:
continue
for write in sch.get(block).writes:
if write.buffer == buffer:
return find_arg_idx_from_buffer_chain(sch, block, buffer)
# if no buffer producer block found, it means the buffer is an input buffer
return find_args_index(sch, buffer)
class IterKind(Enum):
"""Iter kinds for GEMM-liked programs.
We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
where `I, J, K` are fundamental axes for gemm and `S` represents all
other spatial axes (e.g. batches)
kIter_S: spatial axes
kIter_I: I axes
kIter_J: J axes
kIter_K: K axes
kIter_T: trivial axes (i.e. with extent 1)
"""
kIter_S = 0
kIter_I = 1
kIter_J = 2
kIter_K = 3
kIter_T = 4
@dataclass
class IterTrait:
kind: IterKind
extent: PrimExpr
def make_iter_fusion_index_map(
traits: List[IterTrait],
kind_order: List[IterKind],
) -> tir.IndexMap:
fused_iters: Dict[IterKind, PrimExpr] = {}
input_iters: List[tir.Var] = []
for i, trait in enumerate(traits):
v_i = tir.Var(f"i{i}", trait.extent.dtype)
input_iters.append(v_i)
if trait.kind == IterKind.kIter_T:
continue
if trait.kind not in kind_order:
raise ValueError(f"Unknown iter kind {trait.kind}")
if trait.kind in fused_iters:
fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i
else:
fused_iters[trait.kind] = v_i
final_indices: List[tir.PrimExpr] = [
fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
]
return tir.IndexMap(input_iters, final_indices, None)
def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
"""Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
Parameters
----------
block : tir.Block
The block to be analyzed
Returns
-------
traits : Optional[Tuple[List[IterTrait]]]
The detected iter traits for axes in A, B and C. None if the block
does not match the pattern.
"""
if len(block.reads) != 2 or len(block.writes) != 1:
return None
def get_access_axes(region: List[Range]) -> Set[Var]:
axes: Set[Var] = set()
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
axes = axes.union(set(undefined_vars(r.min)))
return axes
try:
A_axes = get_access_axes(block.reads[0].region)
B_axes = get_access_axes(block.reads[1].region)
C_axes = get_access_axes(block.writes[0].region)
except ValueError:
return None
traits: Dict[Var, IterTrait] = {}
for iter_var in block.iter_vars:
var = iter_var.var
kind: IterKind
if _is_one(iter_var.dom.extent):
if iter_var.iter_type == tir.IterVar.CommReduce:
# for simplified case (e.g. 1x1 conv kernel)
kind = IterKind.kIter_K
else:
kind = IterKind.kIter_T
elif iter_var.iter_type == iter_var.DataPar:
if var in A_axes and var in B_axes and var in C_axes:
kind = IterKind.kIter_S
elif var in A_axes and var in C_axes:
kind = IterKind.kIter_I
elif var in B_axes and var in C_axes:
kind = IterKind.kIter_J
else:
return None
elif iter_var.iter_type == tir.IterVar.CommReduce:
if var in A_axes and var in B_axes and var not in C_axes:
kind = IterKind.kIter_K
else:
return None
else:
return None
traits[var] = IterTrait(kind, iter_var.dom.extent)
# A GEMM kernel requires I, J and K axes
gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K}
if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits:
return None
A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes]
B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes]
C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes]
block_traits = [traits[i.var] for i in block.iter_vars]
return A_traits, B_traits, C_traits, block_traits
def get_index_map(block: tir.Block,
layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]:
"""Get index maps for the block
Parameters
----------
block : tir.Block
The block to be analyzed
layout : List[str]
the target layout index map to be used.
'n' for [i, k] layout
't' for [k, j] layout
'a' for auto inference based on whether the last axis is reduction.
Returns
-------
index_maps : Optional[Tuple[tir.IndexMap]]
The index maps for the block, or None if the block is not a gemm-liked kernel
"""
if layout is None:
layout = ["n", "t", "n"]
traits = detect_iter_traits(block)
if traits is None:
return None
A_traits, B_traits, C_traits, block_traits = traits
def get_ordered_axes(region: List[Range]) -> Set[Var]:
axes: List[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
axes.append(r.min)
return axes
def is_common_reduce(var: Var) -> bool:
for iter_var in block.iter_vars:
if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
return True
return False
def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
def infer_layout(layout: str, region: List[Range], kind: str = "A"):
"""
Infer the layout based on the region and the kind of buffer
kind: "A", "B", "C"
"""
primary_iter, secondary_iter, reduction_iter = {
"A": (IterKind.kIter_I, IterKind.kIter_K, IterKind.kIter_K),
"B": (IterKind.kIter_K, IterKind.kIter_J, IterKind.kIter_K),
"C": (IterKind.kIter_I, IterKind.kIter_J, None),
}[kind]
spatial_iter = {
"A": IterKind.kIter_I,
"B": IterKind.kIter_J,
"C": None,
}[kind]
if layout == "n":
return [IterKind.kIter_S, primary_iter, secondary_iter]
elif layout == "t":
return [IterKind.kIter_S, secondary_iter, primary_iter]
elif layout == "a":
# auto inference layout
# for buffer with reduction axis, we put it as the last axis
# otherwise, we put it as the first axis
if kind == "C":
return [IterKind.kIter_S, primary_iter, secondary_iter]
else:
return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region)
else [IterKind.kIter_S, reduction_iter, spatial_iter])
else:
raise ValueError(f"Unknown layout {layout}")
A_index_map = make_iter_fusion_index_map(
A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
B_index_map = make_iter_fusion_index_map(
B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
C_index_map = make_iter_fusion_index_map(
C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
matmul_index_map = make_iter_fusion_index_map(
block_traits,
[IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K],
)
return (
matmul_index_map,
A_index_map,
B_index_map,
C_index_map,
)
def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
"""
Detect in/out data types for the given block based on the analysis of its read/write buffers.
"""
assert len(block.reads) > 0 and len(block.writes) > 0
in_dtype = block.reads[0].buffer.dtype
out_dtype = block.writes[0].buffer.dtype
return (in_dtype, out_dtype)
def get_dequantize_block(sch, blocks) -> Optional[BlockRV]:
# check for at least two inputs and one output;
# at least one input has a uint dtype, and the output dtype is float
def is_dequantize(block: BlockRV) -> bool:
block_stmt = sch.get(block)
if len(block_stmt.reads) < 2:
return False
has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
if not has_uint_input:
return False
return not (len(block_stmt.writes) != 1 or
"float" not in str(block_stmt.writes[0].buffer.dtype))
dequantize_blocks = [block for block in blocks if is_dequantize(block)]
return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
def is_identity_or_transpose_block(block_stmt: tir.Block) -> Tuple[bool, bool]:
iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
if iter_types != {IterVar.DataPar}:
return False, False
if not isinstance(block_stmt.body, tir.BufferStore):
return False, False
if not isinstance(block_stmt.body.value, tir.BufferLoad):
return False, False
def get_access_vars(region: List[Range]) -> List[Var]:
axes: List[Var] = []
for r in region:
if not _is_one(r.extent):
return None
axes.extend(undefined_vars(r.min))
# remove trivial axis
trivial_vars = set(
iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
axes = [axis for axis in axes if axis not in trivial_vars]
# remove duplicate axis
axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]]
return axes
lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:]
rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:]
is_identity = list(lhs_access_vars) == list(rhs_access_vars)
is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(
rhs_access_vars)
return is_identity, is_transpose
def is_identity_block(block_stmt: tir.Block) -> bool:
return is_identity_or_transpose_block(block_stmt)[0]
def is_transpose_block(block_stmt: tir.Block) -> bool:
return is_identity_or_transpose_block(block_stmt)[1]
def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]):
result_blocks = []
for block in blocks:
if not is_transpose_block(sch.get(block)):
result_blocks.append(block)
continue
try:
sch.compute_inline(block)
except Exception:
try:
sch.reverse_compute_inline(block)
except Exception:
result_blocks.append(block)
return result_blocks
def normalize_to_matmul(sch: tir.Schedule,
main_block: BlockRV,
layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
if layout is None:
layout = ["n", "t", "n"]
block_stmt = sch.get(main_block)
# let layout be 'a' to auto inference the layout
index_maps = get_index_map(block_stmt, layout=layout)
if index_maps is None:
logger.debug("Cannot find the appropriate index map for tensorcore")
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
# `skip_simplify` to avoid the bug in the 1x1 conv
block = sch.reindex(main_block, ("read", 0), skip_simplify=True)
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1), skip_simplify=True)
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0), skip_simplify=True)
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)
sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True)
return sch
def get_tensorized_func_and_tags(
func: tir.PrimFunc,
target: Target,
layout: Optional[List[str]] = None,
skip_normalize: bool = False,
allow_gemv: bool = False,
) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
"""
transform function to matmul if necessary (e.g. transform conv2d with im2col)
"""
if layout is None:
layout = ["a", "a", "a"]
# step1. detect whether the function can utilize tensorcore
sch = tir.Schedule(func)
root_block = get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
reduction_blocks = get_reduction_blocks(sch, blocks)
if not reduction_blocks or len(reduction_blocks) != 1:
return func, None
def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool:
block_stmt = sch.get(block)
conditions = []
conditions.append(len(block_stmt.reads) == 2)
conditions.append(len(block_stmt.writes) == 1)
conditions.append(
len(
collect_block_iter_vars_used_in_access_region(block_stmt,
block_stmt.writes[0].region)) > 0)
return all(conditions)
# step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
target: Target) -> Union[bool, Dict]:
tags: Dict[str, Union[List[int], int]] = {}
block_stmt = sch.get(block)
# NVIDIA only supports Tensor Cores on
# devices with sm_70 or higher.
if check_sm_version(target.arch) < 70:
return False
# analysis tensorcore axis
# todo(lei): maybe we can remove this in the future
(write_buffer_region,) = block_stmt.writes
out_axis = len(write_buffer_region.buffer.shape)
tags["tensorcore_config"] = [out_axis - 2, out_axis - 1]
# analysis pipeline stage
# todo(lei): maybe we can integrate this into policy in the future
tags["pipeline_stage"] = 1
if target.kind.name == "cuda" and check_sm_version(target.arch) == 80:
# enable pipeline stage only for sm_80 devices
tags["pipeline_stage"] = 2
# analysis async copy
# todo(lei): maybe we can integrate this into policy in the future
tags["use_async_copy"] = False
if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) >= 80:
# async copy only works in software pipeline.
tags["use_async_copy"] = True
# analysis intrin information
def get_ordered_axes(region: List[Range]) -> Set[Var]:
axes: List[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
axes.append(r.min)
return axes
def is_common_reduce(var: Var) -> bool:
for iter_var in block_stmt.iter_vars:
if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
return True
return False
def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
intrin_info: dict = {}
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
intrin_info["in_dtype"] = in_dtype
intrin_info["out_dtype"] = out_dtype
if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32":
# INT32-accumulation Tensor Core ops require sm_80 or higher.
return False
# if the last dimension is reduce axis, the B is transposed
intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region)
if func.attrs is not None and "input_transform_kind" in func.attrs:
intrin_info["input_transform_kind"] = func.attrs["input_transform_kind"]
if func.attrs is not None and "weight_transform_kind" in func.attrs:
intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"]
tags["intrin_info"] = intrin_info
# Analysis Block Reduction Optimization
# Currently, we only support block reduction depth 2 for small M.
# When the func is a dequantize-like op, we should consider M.
require_block_reduce = False
# And we only support float16 for now
if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
for arg in func.params:
inp_shape = func.buffer_map[arg].shape
M = inp_shape[0]
if isinstance(M, tir.IntImm) and M <= 128:
require_block_reduce = True
break
if require_block_reduce and check_sm_version(target.arch) == 80:
tags["block_reduction_depth"] = 2
return tags
(main_block,) = reduction_blocks
if not _can_be_tensorized(sch, main_block):
return func, None
block_stmt = sch.get(main_block)
if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
logger.debug(
f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore"
)
return func, None
# reindex and transform functions
# Normalize tensor functions to C[S, I, J] += A[S, I, K] * B[S, J, K]
# or C[S, I, J] += A[S, I, K] * B[S, K, J]
# skip normalize when we want to detect tags only.
if not skip_normalize:
sch = normalize_to_matmul(sch, main_block, layout)
if sch is None:
return func, None
block_stmt = sch.get(main_block)
# 16 for 16-bit tensor cores, 32 for 8-bit tensor cores.
minimal_tensorize_spatial_threshold = 16
minimal_tensorize_reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
for item_var in block_stmt.iter_vars[1:]:
extent = item_var.dom.extent
iter_type = item_var.iter_type
if iter_type is IterVar.DataPar:
minimal_tensorize_threshold = minimal_tensorize_spatial_threshold
elif iter_type is IterVar.CommReduce:
minimal_tensorize_threshold = minimal_tensorize_reduce_threshold
else:
raise ValueError(f"Unknown IterVar type {iter_type}")
if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
return func, None
tags = analysis_tensorcore_tags(sch, main_block, target)
return sch.mod["main"], tags
return func, None
def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"):
from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel
ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b,
)
assert dtype in [
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support bfloat16, float16, int8, e4m3_float8, e5m2_float8"
# TODO(lei): actually should analyze based on bits instead of dtype
if dtype in ["bfloat16", "float16"]:
ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
# int8 mma only support 32x16 to 16x32 layout
if matrix_name == "A" and trans is False:
ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a
elif matrix_name == "B" and trans is True:
ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b
else:
raise ValueError("Unknown matrix name ", matrix_name)
# The intra-warp memory layout is introduced by ldmatrix; we should lift the ldmatrix transformation out.
def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 8
local_id = kernel_j % 8
return ldmatrix_layout(thread_id, local_id)
def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 8
local_id = kernel_j % 8
return ldmatrix_layout_trans(thread_id, local_id)
def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 16
local_id = kernel_j % 16
return ldmatrix_layout(thread_id, local_id)
if dtype in ["bfloat16", "float16"]:
ldmatrix_index_map = (
ldmatrix_trans_permutation_16x16_32x8_16x16
if trans else ldmatrix_permutation_16x16_32x8_16x16)
else:
ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16
ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype)
# TODO(lei): index_dtype should be analyzed from the schedule
row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
inversed_index_map = ldmatrix_index_map.inverse([row, col])
return ldmatrix_index_map, inversed_index_map
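# --- Illustrative usage sketch (not part of the original module) ---
# A hypothetical example of consuming the propagate map returned above; the exact
# remapped coordinates depend on the imported ldmatrix layout functions.
#
#   index_map, inverse_map = get_propagate_map(trans=False, dtype="float16", matrix_name="A")
#   # `index_map` can be passed to a tir.Schedule.transform_layout call to rewrite a
#   # 16x16 tile into the ldmatrix-friendly ordering, and `inverse_map` undoes it.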
# This function is used to get the index map for stage 3 of the Ladder weight
# propagation, which can be used to avoid ldmatrix instructions.
def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
def shared_32x8_to_mma_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j // 2)
local_id = (i // 8) * 2 + (j % 2)
return thread_id, local_id
def shared_32x16_to_mma_32x16_layout(i, j):
thread_id = (i % 8) * 4 + (j // 4)
local_id = (i // 8) * 4 + (j % 4)
return thread_id, local_id
assert dtype in [
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype in ["bfloat16", "float16"]:
stage3_layout = shared_32x8_to_mma_32x8_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
stage3_layout = shared_32x16_to_mma_32x16_layout
else:
raise ValueError("Unknown dtype ", dtype)
    # The intra-warp memory layout is produced by ldmatrix, so we lift the ldmatrix permutation out.
def ladder_stage3_permutation_16x16_32x8_32x8_16x16(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 8
local_id = kernel_j % 8
new_thread_id, new_local_id = stage3_layout(thread_id, local_id)
new_kernel_i = (new_thread_id * 8 + new_local_id) // 16
new_kernel_j = (new_thread_id * 8 + new_local_id) % 16
return new_kernel_i, new_kernel_j
def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j):
thread_id = kernel_i * 2 + kernel_j // 16
local_id = kernel_j % 16
new_thread_id, new_local_id = stage3_layout(thread_id, local_id)
new_kernel_i = (new_thread_id * 16 + new_local_id) // 32
new_kernel_j = (new_thread_id * 16 + new_local_id) % 32
return new_kernel_i, new_kernel_j
if dtype in ["bfloat16", "float16"]:
stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16
else:
stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32
stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype)
# TODO(lei): index_dtype should be analyzed from the schedule
row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
inversed_index_map = stage3_index_map.inverse([row, col])
return stage3_index_map, inversed_index_map
def layout_propagate_chain(
sch: tir.Schedule,
start_block: BlockRV,
start_buffer: tir.Buffer,
end_block: BlockRV,
index_map: IndexMap,
):
# some layout transformation may only apply to the last n dimensions
# propagate the layout transformation to the chain of blocks
block = start_block
buffer = start_buffer
index_map = index_map
while True:
last_buffer = buffer
producers = sch.get_producers(block)
if len(producers) == 0:
break
for producer in producers:
if len(sch.get(producer).writes) != 1:
return index_map
if sch.get(producer) == sch.get(end_block):
return index_map
(write,) = sch.get(producer).writes
read = find_first_similar_region(sch.get(producer).reads, last_buffer)
if write.buffer == buffer:
block = producer
buffer = read.buffer
write_indices = [r.min for r in write.region]
read_indices = [r.min for r in read.region]
# reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout
tmp_index_map = IndexMap(write_indices, read_indices, None)
tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0]
# if dequantize like ops are used, the scaling factor should be considered
# to be applied to the final indices
scaling_factor = 1
for i, j in zip(write.buffer.shape, read.buffer.shape):
scaling_factor *= i // j
final_indices = list(
index_map.map_indices(tmp_index_map.map_indices(write_indices)))
final_indices[-1] = final_indices[-1] // scaling_factor
index_map = IndexMap(
write_indices,
final_indices,
None,
)
if buffer == last_buffer:
break
return index_map
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401
from ..arch import TileDevice, CUDA # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Benefit For Carver Schedule"""
class Block:
def __init__(self, start, end, is_free):
self.start = start
self.end = end
self.is_free = is_free
def size(self) -> int:
return self.end - self.start
def merge(self, other):
assert self.is_free == other.is_free
self.start = min(self.start, other.start)
self.end = max(self.end, other.end)
def __repr__(self) -> str:
return "<Block offset={} size={}>".format(self.start, self.size())
class BestFit:
def __init__(self, align=32):
self.limit = 0
self.list = []
self.align = align
def malloc(self, size) -> Block:
size = (size + self.align - 1) // self.align * self.align
found = None
for block in self.list:
            if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
found = block
if found:
found.is_free = False
remain = found.size() - size
if remain != 0:
found.end -= remain
self.list.insert(
self.list.index(found) + 1, Block(found.end, found.end + remain, True))
return found
elif len(self.list) > 0 and self.list[-1].is_free:
add = size - self.list[-1].size()
self.list[-1].end += add
self.limit = self.list[-1].end
self.list[-1].is_free = False
return self.list[-1]
else:
block = Block(self.limit, self.limit + size, False)
self.list.append(block)
self.limit += size
return block
def free(self, block: Block) -> None:
assert not block.is_free
idx = self.list.index(block)
self.list[idx] = Block(block.start, block.end, True)
if idx + 1 < len(self.list) and self.list[idx + 1].is_free:
self.list[idx].merge(self.list[idx + 1])
self.list.pop(idx + 1)
if idx - 1 >= 0 and self.list[idx - 1].is_free:
self.list[idx].merge(self.list[idx - 1])
self.list.pop(idx - 1)
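# --- Illustrative usage sketch (not part of the original module) ---
# A minimal walk-through of the best-fit allocator with the default 32-byte
# alignment; requested sizes are rounded up to the alignment before placement.
#
#   allocator = BestFit(align=32)
#   a = allocator.malloc(100)   # rounded to 128 bytes -> <Block offset=0 size=128>
#   b = allocator.malloc(48)    # rounded to 64 bytes  -> <Block offset=128 size=64>
#   allocator.free(a)           # bytes 0..128 become a free block again
#   c = allocator.malloc(64)    # best fit reuses the freed region -> <Block offset=0 size=64>
#   assert allocator.limit == 192   # high-water mark of the arena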
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Hint definition for schedule"""
from tvm import DataType
from typing import Dict, List, Tuple
from . import PrimFuncNode
import numpy as np
from .rasterization import *
class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
"""
def __init__(
self,
AS_shape: Tuple[int],
BS_shape: Tuple[int],
AF_shape: Tuple[int],
BF_shape: Tuple[int],
tc_axis: Tuple[int],
) -> None:
self.AS_shape: Tuple[int] = AS_shape
self.BS_shape: Tuple[int] = BS_shape
self.AF_shape: Tuple[int] = AF_shape
self.BF_shape: Tuple[int] = BF_shape
self.tc_axis: Tuple[int] = tc_axis
class Stride:
"""
Manages stride information for a given axis of a tensor.
"""
def __init__(self, stride: int = 1, ax: int = -1) -> None:
# which axis to put stride on
self._ax: int = int(ax)
# the stride size of the axis
self._stride: int = int(stride)
@property
def ax(self) -> int:
return self._ax
@property
def stride(self) -> int:
return self._stride
def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
ndim = len(shape)
strides = [1 for _ in shape]
for i in range(ndim - 2, -1, -1):
if i == self.ax:
strides[i] = self.stride
else:
strides[i] = int(strides[i + 1] * shape[i + 1])
return strides
def compute_elements_from_shape(self, shape: List[int]) -> int:
original_shape = np.prod(shape)
if not self.is_valid():
strided_elem = original_shape
else:
assert self.ax < len(shape)
strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride
assert strided_elem >= original_shape
return int(strided_elem)
def is_valid(self) -> bool:
return self.ax >= 0
def __repr__(self) -> str:
return f"<Stride, {self._ax}, {self._stride}>"
class TileDict:
"""
Manages tiling information and configurations for computational tasks.
"""
def __init__(self, output_tile) -> None:
self.output_tile = output_tile
# schedule config
self.tile_map = {}
self.rstep_map = {}
self.cached_tensors_map = {}
self.output_strides_map = {}
self.tensor_strides_map = {}
# analysis
self.traffic = -1
self.smem_cost = -1
self.block_per_SM = -1
self.num_wave = -1
self.grid_size = -1
self.valid = True
def get_tile(self, func) -> List[int]:
return self.tile_map[func]
def get_rstep(self, func) -> Dict[str, int]:
return self.rstep_map
def __hash__(self) -> int:
return hash(tuple(self.output_tile))
class IntrinInfo:
"""
    Information related to the tensorcore intrinsic.
"""
def __init__(
self,
in_dtype: str,
out_dtype: str,
trans_b: bool,
input_transform_kind: int = 0,
weight_transform_kind: int = 0,
) -> None:
self.in_dtype = in_dtype
self.out_dtype = out_dtype
self.trans_a = False
self.trans_b = trans_b
self.input_transform_kind = input_transform_kind
self.weight_transform_kind = weight_transform_kind
def __repr__(self) -> str:
return f"<IntrinInfo, {self.in_dtype}, {self.out_dtype}, {self.trans_b}, {self.propagate_b}>"
def is_input_8bit(self) -> bool:
return DataType(self.in_dtype).bits == 8
@property
def smooth_a(self) -> bool:
return self.input_transform_kind >= 2
@property
def smooth_b(self) -> bool:
return self.weight_transform_kind >= 2
@property
def inter_transform_a(self) -> bool:
return self.input_transform_kind >= 1
@property
def inter_transform_b(self) -> bool:
return self.weight_transform_kind >= 1
class Hint(object):
"""
Central configuration class for managing various parameters of computational tasks.
"""
def __init__(self) -> None:
self.arch = None
self.use_tc = None # todo(lei): this should be renamed.
# Special axes tiling info
self.block = []
self.thread = []
# Special axes for MFMA
self.warp = []
# Reduce axes tiling info
self.rstep = []
self.reduce_thread = []
self.rasterization_plan = NoRasterization()
self.cached_tensors = []
self.output_strides = {}
self.schedule_stages = None
# Config for block reduction
self.block_reduction_depth = None # type: int
# TL Specific
# Split-K factor for SM waste optimization
self.split_k_factor: int = 1
# Experimental
self._raxis_order = []
self._step = []
self.vectorize: Dict[str, int] = {}
self.pipeline_stage = 1
self.use_async = False
self.opt_shapes: Dict[str, int] = {}
self.intrin_info = IntrinInfo("float16", "float16", True)
self.shared_scope: str = "shared"
self.pass_context: Dict = {}
def to_dict(self) -> Dict:
dic = {}
dic["block"] = self.block
if self.use_tc:
dic["warp"] = self.warp
else:
dic["thread"] = self.thread
dic["rstep"] = self.rstep
if np.prod(self.reduce_thread) > 1:
dic["reduce_thread"] = self.reduce_thread
if self.use_tc:
dic["use_tc"] = self.use_tc
if self.output_strides:
dic["strides"] = {}
for k, stride in self.output_strides.items():
if stride.is_valid():
dic["strides"][k] = stride
if len(dic["strides"]) == 0:
del dic["strides"]
if np.prod(self._step) > 1:
dic["step"] = self._step
if self._raxis_order != []:
dic["raxis_order"] = self._raxis_order
if self.vectorize != {}:
dic["vectorize"] = self.vectorize
if self.pipeline_stage != 1:
dic["pipeline_stage"] = self.pipeline_stage
if self.block_reduction_depth is not None:
dic["block_reduction_depth"] = self.block_reduction_depth
return dic
@classmethod
def from_dict(cls, dic: Dict) -> "Hint":
hint = cls()
for k, v in dic.items():
setattr(hint, k, v)
return hint
def tensorcore_legalization(self):
# only keep the last 2 axes for tensorcore
self.warp = self.warp[-2:]
self.block = self.block[-2:]
return self
@property
def raxis_order(self) -> List[int]:
if self._raxis_order != []:
return self._raxis_order
return list(range(len(self.rstep)))
@property
def step(self) -> List[int]:
if self._step != []:
return self._step
return [1 for _ in self.block]
def __repr__(self) -> str:
return str(self.to_dict())
def complete_config(self, node: PrimFuncNode):
# analysis pass context, for int8 mma, we should merge static shared memory
merge_static_smem = False
# int32 and float32 accum may take too much shared memory
if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
merge_static_smem = True
# Always merge dynamic shared memory
if self.shared_scope == "shared.dyn":
merge_static_smem = True
self.pass_context = {"tir.merge_static_smem": merge_static_smem}
return self
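# --- Illustrative usage sketch (not part of the original module) ---
# Hint round-trips through a plain dictionary; keys holding default values are
# omitted by to_dict(). The values below are made up for illustration.
#
#   hint = Hint.from_dict({"block": [128, 128], "thread": [16, 16], "rstep": [32]})
#   hint.to_dict()   # -> {'block': [128, 128], 'thread': [16, 16], 'rstep': [32]}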
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""PrimFunc Wrapper and Block information Analaysis"""
import tvm
from tvm import tir
from tvm.tir import IterVar, PrimFunc
from typing import Any, Dict, List, Tuple, Optional
from tvm.tir.schedule.schedule import BlockRV
import numpy as np
import functools
from ..analysis import BlockInfo, get_reduction_blocks
from .. import analysis
from .. import normalize_prim_func
from .shape_inference import get_analyzer_by_tir
def pre_order_traverse(block_analyzer, blocks, func):
visited = set()
def _traverse(block):
if block in visited:
return
visited.add(block)
for dep_blocks in block_analyzer.get_consumer_blocks(block):
_traverse(dep_blocks)
func(block)
for block in blocks:
_traverse(block)
class BlockAnalyzer(object):
def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch
self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch)
def get_block_name(self, block: BlockRV) -> str:
return self.sch.get(block).name_hint
def get_block_info(self, block: BlockRV) -> BlockInfo:
for block_info in self.block_infos:
if self.get_block_name(block) == block_info.name:
return block_info
return None
def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
block_info = self.get_block_info(block)
axis = []
for iter in block_info.iters:
if iter.kind == "S":
axis.append(iter)
return axis
def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
block_info = self.get_block_info(block)
raxis = []
for iter in block_info.iters:
if iter.kind == "R":
raxis.append(iter)
return raxis
def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]:
buffers = []
for read in self.sch.get(block).reads:
buffers.append(read.buffer)
return buffers
def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]:
buffers = []
for write in self.sch.get(block).writes:
buffers.append(write.buffer)
return buffers
def get_buffers(self, block: BlockRV) -> List[tir.Buffer]:
return self.get_input_buffers(block) + self.get_output_buffers(block)
def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]:
return self.sch.get_producers(block)
def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]:
return self.sch.get_consumers(block)
class Node(object):
def __init__(self, tags: Optional[Dict] = None) -> None:
if tags is None:
tags = {}
self._dtypes = []
self._tag: Dict = {}
for tag in tags:
self.add_tag(tag, tags[tag])
def set_tag(self, k: str, v: Any = True) -> None:
self.add_tag(k, v)
def add_tag(self, k: str, v: Any = True) -> None:
self._tag[k] = v
def get_tag(self, k: str) -> Any:
if k not in self._tag:
return None
return self._tag[k]
class PrimFuncNode(Node):
def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None:
super().__init__(tags)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
self.schedule_stages: List[BlockRV] = []
self.blocks: List[BlockRV] = []
self.output_blocks: List[BlockRV] = None
self.reduction_block: BlockRV = None
self.raxis = []
self.input_buffers = []
self.output_buffers = []
self.buffers = []
self.args = []
self._analysis_funcinfo()
self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks)
def _specialize_func(self, func: PrimFunc):
# Specialize the function to make it more friendly for analysis.
# set attrs
for k, v in func.attrs.items():
self.set_tag(k, v)
if self.get_tag("is_speclized"):
return func
opt_shapes = self.get_tag("opt_shapes")
if opt_shapes:
for name, shape in opt_shapes.items():
var = analysis.find_var_from_func(func, name)
if var is not None:
func = func.specialize({var: shape.astype(var.dtype)})
return func
def _analysis_funcinfo(self):
root_block = analysis.get_root_block(self.sch)
blocks = self.sch.get_child_blocks(root_block)
self.blocks = blocks
self.output_blocks = self.sch.get_output_blocks(root_block)
reduction_blocks = get_reduction_blocks(self.sch, blocks)
if reduction_blocks is None:
self.reduction_block = None
            self.schedule_stages.extend(self.output_blocks)
else:
# analysis on the last reduction block
self.reduction_block = reduction_blocks[-1]
# set raxis
reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block)
for iter in reduce_block_info.iters:
if iter.kind == "R":
self.raxis.append(iter)
self.schedule_stages.append(self.reduction_block)
# collect output buffers
for output_block in self.output_blocks:
for write in self.sch.get(output_block).writes:
if write not in self.output_buffers:
self.output_buffers.append(write.buffer)
for param in self.prim_func.params:
if param not in self.prim_func.buffer_map:
                # dynamic symbolic vars may appear in params without a buffer binding
continue
buffer = self.prim_func.buffer_map[param]
if buffer not in self.output_buffers:
self.input_buffers.append(buffer)
self.args = self.input_buffers + self.output_buffers
self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()]
# set dtype
self.set_dtype(tvm.DataType(self.output_buffers[0].dtype))
def get_opt_shape(self, name) -> int:
opt_shapes = self.get_tag("opt_shapes")
if opt_shapes is None:
return None
return opt_shapes[name]
def extent_wrapper(self, value) -> int:
if isinstance(value, tvm.tir.Var):
return self.get_opt_shape(value.name)
elif isinstance(value, tvm.tir.IntImm):
return int(value)
else:
return value
@functools.lru_cache()
def get_space_dim(self) -> List[int]:
dim_size = []
if self.reduction_block:
block_info = self.block_analyzer.get_block_info(self.reduction_block)
for iter in block_info.iters:
if iter.kind == "S":
if isinstance(iter.dom.extent, tvm.tir.IntImm):
dim_size.append(int(iter.dom.extent))
else:
assert isinstance(iter.dom.extent, tvm.tir.Var)
dim_size.append(self.get_opt_shape(iter.dom.extent.name))
else:
# assume outer stage has the same shape
loops = self.sch.get_loops(self.schedule_stages[0])
for loop in loops:
dim_size.append(int(self.sch.get(loop).extent))
return [int(x) for x in dim_size]
def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
assert isinstance(dtype, tvm.DataType), type(dtype)
if dtype == tvm.DataType("bool"):
dtype = tvm.DataType("int8")
if len(self._dtypes) <= id:
self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)])
elif self._dtypes[id] is not None:
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
def get_dtype(self, id=0) -> tvm.DataType:
return self._dtypes[id]
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)
def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
if rstep is None:
rstep = {}
shape = {
self.block_analyzer.get_output_buffers(block)[0].name: [
tvm.arith.ConstIntBound(0, val - 1) for val in tile
] for block in self.schedule_stages
}
return self.ana.infer(shape, rstep, targets)
def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
targets = [t.name for t in self.args[:read_idx_offset]]
shapes, intermediate_bind = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(self.args[:read_idx_offset]):
if arg.name in intermediate_bind:
results.append(shapes[arg.name])
continue
# should not exceed original shape
trimmed_shape = [
self.extent_wrapper(i)
for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
]
results.append(trimmed_shape)
return results
# Propagate inputs only on reduction block
def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
reduction_block = self.reduction_block
args = self.block_analyzer.get_input_buffers(reduction_block)
targets = [t.name for t in args]
shapes, intermediate_bind = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(args):
if arg.name in intermediate_bind:
results.append(shapes[arg.name])
continue
# should not exceed original shape
propagate_shape = shapes[arg.name]
buffer_shape = args[i].shape
if len(buffer_shape) > len(propagate_shape):
buffer_shape = buffer_shape[-len(propagate_shape):]
trimmed_shape = [
self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))
]
results.append(trimmed_shape)
return results
def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
targets = [t.name for t in self.args[read_idx_offset:]]
shapes, _ = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(self.args[read_idx_offset:]):
# should not exceed original shape
            trimmed_shape = list(map(min, zip(shapes[arg.name], arg.shape)))
results.append(trimmed_shape)
return results
def propagate_reduction_inputs(self,
shape,
rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
if rstep is None:
rstep = {}
if self.reduction_block is None:
return {}
targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)]
results, _ = self.propagate(shape, rstep, targets)
return results
def get_reduce_inputs_dtype(self):
if self.reduction_block is None:
return {}
return {
b.name: tvm.DataType(b.dtype)
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
@functools.lru_cache()
def infer_tensorcore_axis(self) -> Tuple[int]:
        # the axis assignment is fixed for a given expression, so infer it once and cache the result
assert self.get_tag("tensorcore_config")
C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok
output_buffer_shape = (
self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
valid_region = []
for region in output_buffer_shape:
if region.value == 1:
continue
valid_region.append(region)
num_nvalid_regions = len(output_buffer_shape) - len(valid_region)
self.set_tag("num_nvalid_regions", num_nvalid_regions)
def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
spatial_dim = self.get_space_dim()
assert len(valid_region) == len(
spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
cl_shapes = [1] * len(spatial_dim)
cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n
return cl_shapes
CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions)
self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]])
shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis})
A_deps, B_deps = shapes.values()
A_ax_m = A_deps.index(wmma_m)
B_ax_n = B_deps.index(wmma_n)
CL_shape = [1] * len(self.get_space_dim())
shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis})
A_deps, B_deps = shapes.values()
A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k)
B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k)
tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
return tc_axis
def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int:
if stride_map is None:
stride_map = {}
result = 0
shapes, _ = self.propagate(shape, rstep)
def is_broadcast_pattern(buffer, output_buffer):
return (buffer in self.args and
len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and
np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]))
def is_after_reduce_stage(block):
if not self.reduction_block:
return False
reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None)
if reduce_dependent_blocks is None:
reduce_dependent_blocks = set()
pre_order_traverse(
self.block_analyzer,
[self.reduction_block],
lambda block: reduce_dependent_blocks.add(block),
)
self.reduce_dependent_blocks = reduce_dependent_blocks
return block not in reduce_dependent_blocks
# compute cached stages
cached_tensor = []
for block in self.blocks:
output_buffer = self.block_analyzer.get_output_buffers(block)[0]
for buffer in self.block_analyzer.get_input_buffers(block):
cache = buffer.name not in cached_tensor and (
is_broadcast_pattern(buffer, output_buffer) or
self.block_analyzer.get_block_info(block).is_reduction)
if not cache:
continue
cached_tensor.append(buffer.name)
if is_after_reduce_stage(block):
continue # cache after reduce op can often reuse buffer in reduce stage
if buffer.name in stride_map:
num_elem = stride_map[buffer.name].compute_elements_from_shape(
shapes[buffer.name])
else:
num_elem = np.prod(shapes[buffer.name])
buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)
buffer_len = (buffer_len + 31) // 32 * 32
result += buffer_len
return result, cached_tensor
def get_input_buffers(self) -> List[tir.Buffer]:
return self.block_analyzer.input_buffers
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .default import DefaultPolicy # noqa: F401
from .tensorcore import TensorCorePolicy # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
import numpy as np
def get_all_factors(n: int) -> List[int]:
# Calculate the square root of n and round it up to the nearest integer
n0 = int(np.ceil(np.sqrt(n)))
# Find all divisors of n that are less than n0
val = np.where(n % np.arange(1, n0) == 0)[0] + 1
# If n is a perfect square, add the square root to the list of factors
mid = np.array([], dtype=int) if n0 * n0 != n else [n0]
# Combine the factors and their corresponding larger pair factors
return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]
def factorize(n: int) -> List[int]:
i = 2 # Start with the smallest prime number
result = []
# Iterate through numbers to find factors
while n > 1:
if n % i == 0: # If i is a factor of n
n //= i # Divide n by i and keep the integer part
result.append(i)
else:
i += 1 # Try the next number
return result
def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
# If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension
if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
return subtensor[-1]
else:
# Recursively calculate the coalesced factor for the remaining dimensions
return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])
def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int:
    # Calculate the total number of elements in the subtensor
    num_elements = int(np.prod(subtensor))
    if num_elements == 0:
        return 0
    # Calculate the coalesced factor for the subtensor
    factor = int(coalesced_factor(subtensor, tensor))
    # Estimate the effective number of elements transferred for coalesced access
    return transaction_size * num_elements / min(transaction_size, factor)
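# --- Illustrative worked examples (not part of the original module) ---
#
#   get_all_factors(12)                            # -> [1, 2, 3, 4, 6, 12]
#   factorize(12)                                  # -> [2, 2, 3]
#   coalesced_factor([16, 8], [64, 32])            # -> 8   (last dims differ)
#   coalesced_factor([16, 32], [64, 32])           # -> 512 (32 * 16, last dims match)
#   coalesced_tensor_shape([16, 8], [64, 32], 32)  # -> 32 * 128 / 8 = 512.0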
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for cuda core schedule"""
import functools
import math
from queue import PriorityQueue
from typing import Iterable, Dict, List, Optional
import numpy as np
import tvm
from ...arch import TileDevice
from ..bestfit import BestFit
from ..hint import Hint, Stride, TileDict
from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors
from ..node import PrimFuncNode
from ..rasterization import NoRasterization
class DefaultPolicy:
"""
    Default Policy for fastdlight, a heuristic plan that tries to
    minimize memory traffic and maximize parallelism for the BitBLAS schedule.
"""
def __init__(self,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None) -> None:
if tags is None:
tags = {}
self.arch = arch
self.prim_func_node = PrimFuncNode(func, tags)
self.ordered_nodes = [self.prim_func_node]
self.output_nodes = [self.prim_func_node]
def emit_config(self, topk: int) -> List[Hint]:
base_tile = self.get_base_tile()
if base_tile is None:
return []
rstep_map = self._assign_reduce_step(self.prim_func_node)
        smem_tile_candidates = self.dfs_smem_tile(base_tile, rstep_map)
        results = []
        for td in smem_tile_candidates:
if not self.check_tile_shape_isvalid(td):
continue
self._expand_reduce_axis(td)
for codegen_dicts in self.assign_block_size(td):
results.append(codegen_dicts)
if len(results) >= topk:
break
if len(results) >= topk:
break
return results
def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
_steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()]
steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)]
for i in range(len(steps)):
added = list(
filter(
lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i],
[2, 4, 8, 16, 32],
))
steps[i].extend(added)
steps[i] = sorted(steps[i])
visited_tiles = {}
queue = PriorityQueue()
def prio(td: TileDict):
return (td.traffic + 1) * td.num_wave
def add_to_queue(tile):
if tuple(tile) in visited_tiles:
return
td = self.compute_tile_dict(tile, rstep_map)
visited_tiles[tuple(tile)] = td
if td.valid:
queue.put([prio(td), tile])
add_to_queue(init_tile)
while not (queue.empty() or len(visited_tiles) > 2000):
_, tile = queue.get()
dim_ids = [step.index(t) for step, t in zip(steps, tile)]
for i in reversed(range(len(dim_ids))):
if dim_ids[i] + 1 < len(steps[i]):
new_tile = tile.copy()
new_tile[i] = steps[i][dim_ids[i] + 1]
add_to_queue(new_tile)
visited_tiles = filter(lambda td: td.valid, visited_tiles.values())
sorted_tiles = sorted(visited_tiles, key=lambda td: prio(td))
return sorted_tiles
def get_base_tile(self):
"""
Gets the minimum tile configuration that satisfies no redundancy in computation.
Returns
-------
List[int]
The base tile configuration, which is a list of 1s equal in length to the space dimensions
of the primary function node.
"""
shape = self.prim_func_node.get_space_dim()
base_tile = [1 for _ in shape]
return base_tile
# handles multiple output cases
def _get_output_tile_map(self, tile):
"""
Handles multiple output cases by mapping output nodes to their respective tile configurations.
Parameters
----------
tile : List[int]
The tile configuration.
Returns
-------
Dict
A dictionary mapping the primary function node to its corresponding tile configuration
based on the output nodes' space dimensions.
"""
tile_map = {}
tile_map[self.prim_func_node] = [
tile[i] * self.prim_func_node.get_space_dim()[i] //
self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))
]
return tile_map
def score_block_size(self, n):
"""
Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition.
Parameters
----------
n : int
The block size to score.
Returns
-------
Tuple[float, float]
A tuple containing two scores representing efficiency and fit, respectively.
"""
        num_warp = (n + self.arch.warp_size - 1) // self.arch.warp_size
        r1 = max(num_warp / self.arch.sm_partition, self.arch.sm_partition / num_warp)
        r2 = (num_warp * self.arch.warp_size - n) / n
return (r1, r2)
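    # Illustrative example (assuming warp_size=32 and sm_partition=4, which are
    # typical CUDA values but not guaranteed by this class):
    #   score_block_size(256) -> (2.0, 0.0)      # 8 warps, no padding waste
    #   score_block_size(96)  -> (~1.33, 0.0)    # 3 warps, closer to the SM partition
    # Smaller tuples sort first, so 96 would be preferred over 256 here.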
def get_block_size(self, n):
"""
Determines the optimal block size for a given constraint, based on scoring various factors.
Parameters
----------
n : int
The constraint size.
Returns
-------
int
The optimal block size chosen from the factors of n, constrained by a maximum of 1024 and
scored by the `score_block_size` method.
"""
factors = get_all_factors(n)
factors = list(filter(lambda x: x <= 1024, factors))
factor_ordered = sorted(factors, key=self.score_block_size)
return factor_ordered[0]
def get_node_reduce_step_candidates(self, node: PrimFuncNode):
"""
        Calculates reduction step candidates for each reduction axis in a PrimFuncNode.
        General idea: prefer exact factors first, since they do not require extra boundary
        checks; for a large prime extent, which is a rare case, use powers of 2.
Parameters
----------
node : PrimFuncNode
The node for which to calculate reduction step candidates. It contains reduction axes (raxis)
with their domains (dom.extent).
Returns
-------
Dict[str, List[int]]
A dictionary mapping axis variable names to lists of step candidates. For each axis in the node,
this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2
as step candidates; for others, it uses all factors of the domain.
"""
results = {}
for k_iter in node.raxis:
all_factors = get_all_factors(int(k_iter.dom.extent))
if len(all_factors) == 2 and int(k_iter.dom.extent) > 64:
all_factors = [1]
while all_factors[-1] * 2 < int(k_iter.dom.extent):
all_factors.append(all_factors[-1] * 2)
results[k_iter.var.name] = all_factors
return results
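    # Illustrative example: for a reduction axis with extent 64 the candidates are
    # all of its factors, [1, 2, 4, 8, 16, 32, 64]; for a large prime extent such as
    # 67 the factor list collapses to [1, 67], so powers of 2 below the extent are
    # used instead, i.e. [1, 2, 4, 8, 16, 32, 64].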
def _assign_reduce_step(self, node: PrimFuncNode):
"""
Assigns an optimal reduction step for the given PrimFuncNode.
Parameters
----------
node : PrimFuncNode
The node for which the reduction step is to be assigned.
Returns
-------
Dict
A dictionary mapping reduction axis variable names to their optimal reduction steps.
"""
if node.reduction_block is None:
return {}
raxis = node.raxis
tile = [1] * len(node.get_space_dim())
all_steps = self.get_node_reduce_step_candidates(node)
def sim(a: int, b: int):
return (2 * a * b) / (a * a + b * b)
def _score(rstep_id):
rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id}
score = 0
shape = node.propagate_inputs(tile, rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers):
read_transaction_elements = self.arch.transaction_size[1] // (
(node.get_buffer_dtype(input_buffer).bits + 7) // 8)
score += sim(
int(coalesced_factor(shape[i], input_buffer.shape)),
read_transaction_elements,
)
return score
def _enlarge(rstep_id):
candidates = []
candidates.append((rstep_id, _score(rstep_id)))
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
best = max(candidates, key=lambda x: x[1])
return best
        # enlarge rstep to ensure reads are coalesced
cur_rstep_id = {ax.var.name: 0 for ax in raxis}
cur_score = _score(cur_rstep_id)
while True:
if cur_score == 0:
break
new_rstep, new_score = _enlarge(cur_rstep_id)
if new_score <= cur_score:
break
else:
cur_rstep_id, cur_score = new_rstep, new_score
rstep = {k: all_steps[k][cur_rstep_id[k]] for k in cur_rstep_id}
return rstep
def _expand_reduce_axis(self, td: TileDict):
"""
Expands the reduction axis in the TileDict based on shared memory limits.
Parameters
----------
td : TileDict
The TileDict object to be optimized.
Returns
-------
None
This function modifies the TileDict in place.
"""
smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
for k in all_steps:
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
def _score(rstep_id):
rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
score = 0
shape = node.propagate_inputs(td.get_tile(node), rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers):
score += coalesced_factor(shape[i], input_buffer.shape)
return score
def _enlarge(rstep_id):
candidates = []
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
if len(candidates) == 0:
return None
return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = self._compute_shared_memory_usage(td)
td.rstep_map = old_rstep_map
if smem_usage > smem_limit:
break
else:
cur_rstep_id = new_rstep_id
rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
return rstep
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
def _compute_memory_traffic(self, output_tile):
"""
Computes the memory traffic for a given output tile configuration.
Parameters
----------
output_tile : List[int]
The output tile configuration.
Returns
-------
Tuple[int, Dict]
The total memory traffic and a map of operation tiles.
"""
op_tile_map = self._get_output_tile_map(output_tile)
traffic = 0
for node in reversed(self.ordered_nodes):
tile = op_tile_map[node]
input_shapes = node.propagate_inputs(tile)
output_shapes = node.propagate_outputs(tile)
for i, buffer in enumerate(node.input_buffers):
nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8
read_transaction_elements = self.arch.transaction_size[1] // nbytes
traffic += (
coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements)
* nbytes)
for i, buffer in enumerate(node.output_buffers):
nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8
write_transaction_elements = self.arch.transaction_size[0] // nbytes
traffic += (
coalesced_tensor_shape(output_shapes[i], buffer.shape,
write_transaction_elements) * nbytes)
return traffic, op_tile_map
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
"""
Infers the shared memory usage of a node given a TileDict configuration.
Parameters
----------
td : TileDict
The TileDict object containing the tile configuration.
node : PrimFuncNode
The node for which to infer the shared memory usage.
Returns
-------
int
The estimated amount of shared memory used by the node.
"""
return node.footprint(td.get_tile(node), td.get_rstep(node), td.tensor_strides_map[node])
def _compute_shared_memory_usage(self, td: TileDict):
"""
        Computes the shared memory usage for a given TileDict configuration.
        Parameters
        ----------
        td : TileDict
            The TileDict object containing the tile configuration.
        Returns
        -------
        Tuple[int, Dict]
            The total shared memory usage in bytes and the map of cached tensors.
"""
self._compute_stride_map(td)
allocator = BestFit()
block_map = {}
cached_tensors_map = {}
node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage(
td, self.prim_func_node)
block = allocator.malloc(node_internal_bytes)
allocator.free(block)
assert len(block_map) == 0
return allocator.limit, cached_tensors_map
def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict):
"""
Computes the stride map for a given node based on the TileDict configuration.
Parameters
----------
node : PrimFuncNode
The node for which to compute the stride map.
td : TileDict
The TileDict object containing the tile configuration.
Returns
-------
Tuple[Dict, Dict]
A tuple of dictionaries containing the output strides and tensor strides.
"""
output_strides = {
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
tensor_strides = {}
return output_strides, tensor_strides
def _compute_stride_map(self, td: TileDict):
"""
Computes the stride map for all nodes in a TileDict.
Parameters
----------
td : TileDict
The TileDict object for which to compute the stride maps.
Returns
-------
None
This function updates the TileDict object in place with the computed stride maps.
"""
output_strides_map = {}
tensor_strides_map = {}
for node in self.ordered_nodes:
output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(
node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict:
"""
Computes and returns a TileDict object for a given output tile configuration and reduction step map.
Parameters
----------
output_tile : List[int]
The output tile configuration.
rstep_map : Dict
The reduction step map.
Returns
-------
TileDict
A TileDict object containing the computed tile configuration, memory traffic, shared memory cost,
grid size, and other related parameters.
"""
td = TileDict(output_tile)
td.rstep_map = rstep_map
td.traffic, td.tile_map = self._compute_memory_traffic(output_tile)
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
if td.smem_cost > self.arch.smem_cap:
td.valid = False
return td
output_shape = self.output_nodes[0].get_space_dim()
td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)]))
# estimated reg usage
reg_usage = int(2 * max([
np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes
]))
if reg_usage > self.arch.reg_cap:
td.valid = False
return td
td.block_per_SM = min(
self.arch.max_smem_usage // max(td.smem_cost, 1),
self.arch.reg_cap // max(reg_usage, 1),
self.arch.sm_partition,
)
td.num_wave = int(np.ceil(td.grid_size / int(td.block_per_SM * self.arch.compute_max_core)))
return td
def check_tile_shape_isvalid(self, td: TileDict) -> bool:
"""
Checks if the tile shapes in the TileDict are valid for the nodes in this context.
Parameters:
- td (TileDict): The TileDict object containing tile shapes and other configurations.
Returns:
- bool: True if all tile shapes are valid, False otherwise.
"""
for node in self.ordered_nodes:
if np.prod(td.get_tile(node)) == 0:
return False
node_grid_size = np.prod([
(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())
])
if node_grid_size != td.grid_size:
return False
if (hasattr(node, "reduce_op") and node.reduce_op is not None and
len(node.reduce_op.axis) == len(td.output_tile)):
for i, tile_extent in enumerate(td.output_tile):
if node.reduce_op.axis[i].dom.extent % tile_extent:
return False
return True
def recommend_block_size(self, td: TileDict) -> List[int]:
"""
Recommends optimal block sizes based on the TileDict configuration.
Parameters
----------
td : TileDict
The TileDict object containing the tile configuration.
Returns
-------
List[int]
A list of recommended block sizes sorted based on their score.
"""
node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes]
max_block_size = functools.reduce(math.gcd, node_space_sizes)
if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(
node_space_sizes):
node_reduce_sizes = [
int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes
]
total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)]
max_possible_size = functools.reduce(math.gcd, total_sizes)
possible_block_sizes = list(
filter(
lambda x: x % max_block_size == 0 and x <= 1024,
get_all_factors(max_possible_size),
))
possible_block_sizes = list(
                filter( # either be a factor of the space or fully cover the space
lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]),
possible_block_sizes,
))
factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
return factor_ordered
else:
possible_block_sizes = get_all_factors(max_block_size)
possible_block_sizes = list(filter(lambda x: x <= 1024, possible_block_sizes))
factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
return factor_ordered
def assign_block_size(self, td: TileDict, topk=1):
"""
Assigns block sizes to the TileDict based on the recommended block sizes.
Parameters
----------
td : TileDict
The TileDict object to assign block sizes to.
topk : int, optional
The number of top block sizes to consider.
Yields
-------
Dict
The block size assignment for the primary function node.
"""
block_size_ordered = self.recommend_block_size(td)
for block_size in block_size_ordered:
result = {}
failed = False
result = self._assign_block_size(self.prim_func_node, td, block_size)
if result is None:
failed = True
break
if failed:
continue
else:
yield result
topk -= 1
if topk == 0:
break
def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
"""
Assigns a block size to a given PrimFuncNode based on the TileDict configuration and the specified block size.
Parameters
----------
node : PrimFuncNode
The node to assign the block size to.
td : TileDict
The TileDict object containing the tile configuration.
block_size : int
The block size to be assigned.
Returns
-------
Hint
A Hint object containing the assigned block size and other related settings.
"""
tile, rsteps = td.get_tile(node), td.get_rstep(node)
factors = factorize(block_size)
cur_threads = [1 for _ in tile]
reduce_thread = {k: 1 for k in rsteps}
ndim = len(tile)
def _score(node, thread): # small is better
score = 0
block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
shape = node.propagate_inputs(block_tile)
for i, _ in enumerate(node.input_buffers):
score += np.prod(shape[i]) / self.arch.bandwidth[1]
for buffer in node.output_buffers:
score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0]
return score
for factor in reversed(factors):
score_map = {}
for i in range(ndim):
if cur_threads[i] >= tile[i]:
continue
if (tile[i] % (cur_threads[i] * factor)) != 0:
continue
cur_threads[i] *= factor
score_map[i] = (_score(node, cur_threads), i)
cur_threads[i] //= factor
if len(score_map) > 0:
# assign to space axis
dim_order = sorted(score_map.keys(), key=lambda x: score_map[x])
cur_threads[dim_order[0]] *= factor
else:
# assign to reduce axis
target_ax = None
for ax, ax_len in reversed(list(rsteps.items())):
if ax_len % (reduce_thread[ax] * factor) == 0:
target_ax = ax
break
assert target_ax
reduce_thread[target_ax] *= factor
codegen_dict = Hint()
codegen_dict.block = tile
codegen_dict.thread = cur_threads
codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis]
codegen_dict.reduce_thread = [reduce_thread[ax.var.name] for ax in node.raxis]
codegen_dict.cached_tensors = td.cached_tensors_map[node]
codegen_dict.rasterization_plan = self.plan_rasterization(td)
if node.get_dtype().bits == 16: # set step=2 for 16bit case to ensure coalesced access
codegen_dict._step = [1 for _ in range(ndim)]
for i in reversed(range(ndim)):
if codegen_dict.block[i] // codegen_dict.thread[i] % 2 == 0:
codegen_dict._step[i] = 2
break
elif node.get_dtype().bits == 8: # set step=4 for 8bit case to ensure coalesced access
codegen_dict._step = [1 for _ in range(ndim)]
for i in reversed(range(ndim)):
if codegen_dict.block[i] // codegen_dict.thread[i] % 4 == 0:
codegen_dict._step[i] = 4
break
# Plan vectorize
codegen_dict.vectorize = self._plan_vectorize(node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
return codegen_dict
def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int):
"""
Plans vectorization for a given PrimFuncNode based on the TileDict configuration and block size.
Parameters
----------
node : PrimFuncNode
The node for which to plan vectorization.
td : TileDict
The TileDict object containing the tile configuration.
block_size : int
The block size used for vectorization planning.
Returns
-------
Dict
A dictionary mapping tensors to their vectorization size.
"""
def is_cont(shape, vec):
if len(shape) == 0:
return vec == 1
last = shape[-1]
if last == 1:
return is_cont(shape[0:-1], vec // last)
else:
return last % vec == 0
def is_shape_aligned(shape, factor):
return int(np.prod(shape)) % factor == 0
def is_type_allowed(dtype, vec):
return dtype.bits * vec <= 128
vectorize_sizes = [16, 8, 4, 2]
dtypes = node.get_reduce_inputs_dtype()
shapes = node.propagate_reduction_inputs(td.get_tile(node), td.get_rstep(node))
vectorize_result = {}
for tensor, shape in shapes.items():
for v in vectorize_sizes:
if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and
is_type_allowed(dtypes[tensor], v)):
vectorize_result[tensor] = v
break
return vectorize_result
def plan_rasterization(self, td: TileDict): # pylint: disable=unused-argument
"""
        Plans the rasterization for the given TileDict. Rasterization planning is
        not implemented in the default policy, so no rasterization is applied.
        Parameters
        ----------
        td : TileDict
            The TileDict object to plan rasterization for.
        Returns
        -------
        Rasterization
            Always a NoRasterization plan.
"""
return NoRasterization()
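# --- Illustrative usage sketch (not part of the original module) ---
# A hypothetical end-to-end call; `matmul_func` is assumed to be a tvm.tir.PrimFunc
# and the CUDA arch constructor is assumed to accept a target string.
#
#   from tilelang.carver.arch import CUDA
#   policy = DefaultPolicy(func=matmul_func, arch=CUDA("cuda"))
#   hints = policy.emit_config(topk=10)
#   for hint in hints:
#       print(hint)   # e.g. {'block': [...], 'thread': [...], 'rstep': [...], ...}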
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for tensorcore schedule"""
import tvm
from typing import Dict, List, Tuple, Optional
import numpy as np
import logging
from ...arch import TileDevice
from ..hint import Hint, Stride, TileDict, IntrinInfo
from ..node import PrimFuncNode
from .common import coalesced_factor, factorize, get_all_factors
from .default import DefaultPolicy
from ..rasterization import NoRasterization, Rasterization2DColumn
logger = logging.getLogger(__name__)
class TensorCorePolicy(DefaultPolicy):
def __init__(self,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None) -> None:
super().__init__(func, arch, tags)
        # wmma_k defaults to 16, matching the 16-bit wmma intrinsics.
        # However, for int8 mma, the wmma_k should be 32.
self.wmma_k = 16
self.pipeline_stage: int = 1
self.use_async_copy: bool = False
self.block_reduction_depth: Optional[int] = None
self._legalize_info()
def _legalize_info(self):
pipleline_stage = self.prim_func_node.get_tag("pipeline_stage")
if pipleline_stage:
self.pipeline_stage = pipleline_stage
else:
if self.arch.compute_capability == "sm_80":
self.pipeline_stage = 2
else:
self.pipeline_stage = 1
use_async_copy = self.prim_func_node.get_tag("use_async_copy")
if use_async_copy:
self.use_async_copy = use_async_copy
else:
if self.arch.compute_capability == "sm_80":
self.use_async_copy = True
else:
self.use_async_copy = False
# TODO: block reduction depth is not used for now.
# As there still exists some performance issues for block reduction.
block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
if block_reduction_depth:
self.block_reduction_depth = block_reduction_depth
def _compute_tc_strides(
self,
node: PrimFuncNode,
tile: List[int],
rstep: Optional[Dict[str, int]] = None,
) -> Tuple[Stride, Stride, Stride]:
if rstep is None:
rstep = {}
        # Strides are used for shared memory padding, which is necessary to avoid
        # shared memory load bank conflicts when we do not apply a tensorcore layout.
shapes = node.propagate_reduction_inputs(tile, rstep)
AS_shape, BS_shape = shapes.values()
CS_shape = tile
A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis()
# applying strides
# TODO(leiwang1999): offset should be dynamically set. we can use tag -> enable_offset to control this option..
offset = 8
A_high_ax = min(A_ax_m, A_ax_k)
B_high_ax = min(B_ax_n, B_ax_k)
C_high_ax = min(C_ax_m, C_ax_n)
A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax)
B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax)
C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax)
return A_stride, B_stride, C_stride
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
value, cached_tensors = super().infer_node_smem_usage(td, node)
value *= self.pipeline_stage
return value, cached_tensors
def _assign_reduce_step(self, node):
if not node.get_tag("tensorcore_config"):
return super()._assign_reduce_step(node)
# get reduce input size
target_transaction = self.arch.transaction_size[0] * 2
# 512 bytes // type bits
reduce_input_dtype = node.get_buffer_dtype(
node.block_analyzer.get_input_buffers(node.reduction_block)[0])
basic = (target_transaction * 8) // reduce_input_dtype.bits
result = {}
for iter_info in node.raxis:
iter_name = iter_info.var.name
iter_dom = iter_info.dom.extent
if iter_dom % 16 > 0:
result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding
elif iter_dom % basic == 0:
result[iter_name] = basic
else:
return super()._assign_reduce_step(node)
return result
def _expand_reduce_axis(self, td: TileDict):
        # For a tensorcore program, if the tile size is small, we should consider expanding the reduce axis
        # to improve compute efficiency.
def _check_small_tile(td: TileDict):
            minimal_threshold = 32
            for node in self.ordered_nodes:
                tile = td.get_tile(node)
                if any([t <= minimal_threshold for t in tile]):
return True
return False
if _check_small_tile(td):
smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
# todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
for k in all_steps:
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
if any([v == [] for v in all_steps.values()]):
return rstep
def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map,
td.tensor_strides_map[node])
def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, input_buffer in enumerate(input_buffers):
score += coalesced_factor(shape[i], input_buffer.shape)
return score
def _enlarge(rstep_id):
candidates = []
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
if len(candidates) == 0:
return None
return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
for k in node.raxis
}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td)
td.rstep_map = old_rstep_map
if smem_usage > smem_limit:
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
return rstep
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
if self.block_reduction_depth is not None:
def _expand_with_tags(rstep):
new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()}
return new_rstep
rstep_map = td.rstep_map.copy()
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _expand_with_tags(rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map
return
def get_node_reduce_step_candidates(self, node):
if not node.get_tag("tensorcore_config"):
return super().get_node_reduce_step_candidates(node)
else:
            # must be a multiple of wmma_k
return {
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}
def check_tile_shape_isvalid(self, td: TileDict):
for node in self.ordered_nodes:
if node.get_tag("tensorcore_config"):
ax_m, ax_n = node.get_tag("tensorcore_config")
block_m, block_n = (
td.tile_map[node][ax_m],
td.tile_map[node][ax_n],
)
# check the tile size is valid
wmma_invalid = [
block_m < wmma_m or block_n < wmma_n
for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
]
if all(wmma_invalid):
return False
if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):
return False
return super().check_tile_shape_isvalid(td)
def _can_implement_layout(self, node: PrimFuncNode, td: TileDict):
# Not implemented yet
# This function is used to check whether we can implement swizzling
# layout under this tile config
return False
def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict):
if not node.get_tag("tensorcore_config"):
return super().compute_node_stride_map(node, td)
use_layout = self._can_implement_layout(node, td)
AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node),
td.get_rstep(node))
A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
tensor_strides = {}
output_strides = {
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
tensor_strides = {}
# when connected to shared input, should use full stride without rstep
for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):
if use_layout:
continue
_ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name
# TODO(lei): should dig further for shared memory connection case.
return output_strides, tensor_strides
def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if not node.get_tag("tensorcore_config"):
return super()._assign_block_size(node, td, block_size)
ax_m, ax_n = node.get_tag("tensorcore_config")
if block_size % self.arch.warp_size != 0:
return None
tile, rsteps = td.get_tile(node), td.get_rstep(node)
warps = block_size // self.arch.warp_size
ndim = len(tile)
wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1]
wmma_tile = [1 for _ in range(ndim)]
wmma_tile[ax_m] = wmma[0]
wmma_tile[ax_n] = wmma[1]
space = [tile[i] // wmma_tile[i] for i in range(ndim)]
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
            # without padding we cannot get a valid tile shape here, so give up on this block size
return None
factors = factorize(np.prod(space) // warps)
def _score(node, thread): # small is better
score = 0
block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
shape = node.propagate_inputs_on_reduction(block_tile)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, _ in enumerate(input_buffers):
score += np.prod(shape[i]) / self.arch.bandwidth[1]
return score
warp_tile = wmma_tile.copy()
for factor in reversed(factors):
score_map = {}
for i in range(ndim):
if tile[i] % (warp_tile[i] * factor) != 0:
continue
warp_tile[i] *= factor
score_map[i] = (_score(node, warp_tile), i)
warp_tile[i] //= factor
if len(score_map) == 0:
return None
dim_order = sorted(score_map.keys(), key=lambda x: score_map[x])
warp_tile[dim_order[0]] *= factor
codegen_dict = Hint()
codegen_dict.block = tile
codegen_dict.warp = warp_tile
codegen_dict.use_tc = True
codegen_dict.pipeline_stage = self.pipeline_stage
codegen_dict.block_reduction_depth = self.block_reduction_depth
codegen_dict.use_async = self.use_async_copy
codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis]
codegen_dict.cached_tensors = td.cached_tensors_map[node]
codegen_dict.rasterization_plan = self.plan_rasterization(td)
intrin_info = node.get_tag("intrin_info")
if intrin_info:
codegen_dict.intrin_info = IntrinInfo(**intrin_info)
if intrin_info["out_dtype"] in ["float32"]:
codegen_dict.shared_scope = "shared.dyn"
        # smem capacity check
        # TODO: This is a dummy multiplier that avoids reusing some shared memory.
        # Should be removed in the future.
        if td.smem_cost > (self.arch.smem_cap):
            # Tile Dict: {td.output_tile} shared memory exceeds the static capacity,
            # fall back to dynamic shared memory.
            codegen_dict.shared_scope = "shared.dyn"
        # currently dynamic shared memory is always used
        codegen_dict.shared_scope = "shared.dyn"
codegen_dict.complete_config(node)
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.tensorcore_legalization()
return codegen_dict
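    # Illustrative walk-through of the warp-tile search above (hypothetical
    # config, assuming factorize returns the prime factorization): with
    # tile == [128, 128], wmma == (16, 16), block_size == 128 and
    # warp_size == 32, we get warps == 4, space == [8, 8] and
    # factorize(64 // 4) == [2, 2, 2, 2]; each factor of 2 is then greedily
    # assigned to the dimension with the lower traffic score, e.g. ending with
    # warp_tile == [64, 32] or [32, 64].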
def plan_rasterization(self, td: TileDict):
conditions = []
        # disabled when there is more than one node (only a single node is supported for now)
        conditions.append(len(self.ordered_nodes) > 1)
        # disabled on pre-Ampere architectures (rasterization requires sm_80+)
        conditions.append(self.arch.compute_capability < "80")
def _check_memory_size():
overall_gmem_size_in_bytes: int = 0
for node in self.ordered_nodes:
for buffer in node.input_buffers:
overall_gmem_size_in_bytes += (
int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8)
return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes
        # disabled when all inputs already fit into the L2 cache
        conditions.append(_check_memory_size())
if any(conditions):
return NoRasterization()
# otherwise, simply provide a block rasterization factor
raster_factor = int(self.arch.compute_max_core**0.5)
return Rasterization2DColumn(raster_factor)
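    # Illustrative example of the decision above (hypothetical device): for a
    # single node on an sm_80+ architecture with compute_max_core == 108 and
    # inputs larger than the L2 cache, every disabling condition is False, so
    # the plan becomes Rasterization2DColumn(int(108 ** 0.5)), i.e.
    # Rasterization2DColumn(10).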
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rasteration Plan For L2 Cache Locality"""
from typing import List, Optional
class Rasterization:
panel_width_ = None
def __init__(self) -> None:
pass
def get_code(self) -> List[str]:
raise NotImplementedError()
@property
def panel_width(self):
assert self.panel_width_ is not None
return self.panel_width_
class NoRasterization(Rasterization):
def __init__(self) -> None:
super().__init__()
def __repr__(self) -> str:
return "<NoRasterization>"
def get_code(self) -> List[str]:
return []
class Rasterization2DRow(Rasterization):
"""
Rasterization by Row, each Row line width is panel_width
_________
_________|
|_________
__________|
"""
def __init__(self, panel_width=4) -> None:
super().__init__()
self.panel_width_ = panel_width
def __repr__(self) -> str:
return f"<Rasterization2DRow({self.panel_width_})>"
def get_code(self) -> List[str]:
raise NotImplementedError()
class Rasterization2DColumn(Rasterization):
"""
Rasterization by Column, each column line width is panel_width
_
| | | |
| | | |
|_| |_|
"""
def __init__(self, panel_width=4) -> None:
super().__init__()
self.panel_width_ = panel_width
def __repr__(self) -> str:
return f"<Rasterization2DColumn({self.panel_width_})>"
def get_device_function(self) -> str:
return """
__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {
    const auto baseBlockIdx = blockIdx.x + gridDim.x * blockIdx.y;
    const auto totalPanel = (gridDim.x * gridDim.y + panel_width * gridDim.x - 1) / (panel_width * gridDim.x);
    const auto totalBlock = gridDim.x * gridDim.y;
    const auto panelIdx = baseBlockIdx / (panel_width * gridDim.x);
    const auto strideLd = panelIdx + 1 < totalPanel ? panel_width : (totalBlock - panelIdx * (panel_width * gridDim.x)) / gridDim.x;
    const auto bx = (panelIdx & 1) ? gridDim.x - (baseBlockIdx - panelIdx * panel_width * gridDim.x) / strideLd - 1 : (baseBlockIdx - panelIdx * panel_width * gridDim.x) / strideLd;
    const auto by = (baseBlockIdx - panelIdx * panel_width * gridDim.x) % strideLd + panelIdx * panel_width;
const auto bz = blockIdx.z;
dim3 blockIdx(bx, by, bz);
return blockIdx;
}
"""
    def get_code(self, panel_width: Optional[int] = None) -> List[str]:
if panel_width is None:
panel_width = self.panel_width_
return [
self.get_device_function(),
"const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width),
]
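# Illustrative usage (a minimal sketch; panel_width == 4 is just an example value):
#   >>> plan = Rasterization2DColumn(panel_width=4)
#   >>> print(plan)
#   <Rasterization2DColumn(4)>
#   >>> prelude, remap = plan.get_code()
# ``prelude`` holds the CUDA helper above and ``remap`` re-binds blockIdx via
# ``rasterization2DColumn(4)`` at the top of the generated kernel.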