Commit cd191889 authored by Lei Wang, committed by GitHub

[Carver] Introduce a tile-structure based cost model for auto tuning (#70)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University

* enhance acknowledgement

* Replace tutorial on memory layout optimization with new tutorial on writing high-performance kernels with thread primitives

* Update subproject commit for TVM dependency

* Update subproject commit for TVM dependency

* Add int4_t type and functions for packing char values in CUDA common header

* Add plot_layout example and implement GetForwardVars method in layout classes

* Refactor code for improved readability by adjusting line breaks and formatting in layout and test files

* Fix formatting by removing unnecessary line break in layout.h

* Refactor make_int4 function for improved readability by adjusting parameter formatting

* Add legend to plot_layout for improved clarity of thread and local IDs

* Remove unnecessary dependencies from requirements files for cleaner setup

* Remove flash_mha.py and add .gitkeep to deepseek_mla directory

* Add build requirements and update installation scripts for improved setup

* Introduce carver

* Refactor imports and improve code formatting for consistency

* Add unit tests for carver recommendation hints

* lint fix

* Enhance ElementwiseTemplate and BaseTemplate with detailed docstrings for improved code documentation and clarity

* Refactor import statements and clean up whitespace in template files for improved readability

* Add README.md for Carver framework with usage examples and architecture support
parent 2411fa28
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.testing
from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch
from typing import List
def run_general_reduction_recommend_hints(structure: str = "SSR",
shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate(
structure=structure,
shape=shape,
dtype=dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=topk)
assert len(hints) > 0, "Hints length is zero"
def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], "float16")
run_general_reduction_recommend_hints("SS", [1024, 1024], "float16")
run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16")
def run_elementwise_recommend_hints(shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate(
shape=shape,
dtype=dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=topk)
assert len(hints) > 0, "Hints length is not topk"
def test_elementwise_recommend_hints():
run_elementwise_recommend_hints([1024, 1024], "float16")
run_elementwise_recommend_hints([1024], "float16")
run_elementwise_recommend_hints([1024, 1024, 1024], "float16")
def run_matmul_recommend_hints(
M: int = 1024,
N: int = 1024,
K: int = 1024,
in_dtype: str = "float16",
out_dtype: str = "float16",
accum_dtype: str = "float16",
):
arch = auto_infer_current_arch()
carve_template = carver.MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=20)
assert len(hints) > 0, "Hints length is not 20"
def test_matmul_recommend_hints():
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float16", "float16")
run_matmul_recommend_hints(1024, 1024, 1024, "int8", "int32", "int32")
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16")
def run_gemv_recommend_hints(N: int = 1024,
K: int = 1024,
in_dtype: str = "float16",
out_dtype: str = "float16",
accum_dtype: str = "float16"):
arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate(
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
hints = carve_template.recommend_hints(topk=20)
assert len(hints) > 0, "Hints length is not 20"
def test_gemv_recommend_hints():
run_gemv_recommend_hints(1024, 1024, "float16", "float16", "float16")
run_gemv_recommend_hints(1024, 1024, "int8", "int32", "int32")
run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16")
if __name__ == "__main__":
tilelang.testing.main()
# Carver: A Tile-Structure-Based Hint Recommendation Framework for Machine Learning Compilers
**Carver** is a lightweight framework for generating and ranking tile configurations (also known as **tiling strategies**, **blocking schemes**, or **scheduling hints**) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.
Carver combines hardware architecture information, user-defined tile structures, and built-in heuristics to recommend tiling strategies (or "hints"). The recommended hints are easily adaptable to multiple backends, including [TVM](https://tvm.apache.org/), [Triton](https://github.com/openai/triton), and [tilelang](https://github.com/LeiYanggh/tilelang), as well as other domain-specific compilers.
---
### Key Features
- **Unified Tiling Framework**: Generate tile candidates for multiple backends under a unified API.
- **Architecture-Specific Modeling**: Take into account architecture constraints (e.g., CUDA `smem_cap`, warp size, CPU cache structure, etc.) when generating hints.
- **Flexible Templates**: High-level templates (like `MatmulTemplate`, `GeneralReductionTemplate`, `ElementwiseTemplate`) let you concisely specify kernel structures.
- **Extendable**: Easily add support for new backends and new operation templates.
---
## Usage Examples
### Basic Usage: General Reduction Template
Once tilelang is installed, you can import Carver and start creating templates:
```python
from tilelang import carver
from tilelang.carver.arch import CUDA
# Instantiate a CUDA device object for an RTX 4090
arch = CUDA("nvidia/geforce-rtx-4090")
# Create a general reduction template for a loop nest:
# for i in Spatial(1024):
# for j in Spatial(1024):
# for k in Reduce(1024):
# ...
carve_template = carver.GeneralReductionTemplate(
structure="SSR",
shape=[1024, 1024, 1024],
dtype="float16",
).with_arch(arch)
# Generate top 20 tile candidates (aka scheduling hints)
hints = carve_template.recommend_hints(topk=20)
for hint in hints:
print(hint)
```
**Example Output** (truncated):
```python
{
'block': [1, 128],
'thread': [1, 128],
'rstep': [64],
...
},
{
'block': [2, 64],
'thread': [2, 64],
'rstep': [64],
...
},
...
{
'block': [1, 16],
'thread': [1, 16],
'rstep': [512],
'reduce_thread': [8],
...
}
```
A tile structure composed of S (spatial) and R (reduce) axes can describe a variety of workloads. For example, the structure `SS` represents a 2D element-wise operation, while `SSR` can represent a general matrix multiplication.
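For instance, the same template class can be instantiated with an `SS` structure to model a 2D elementwise-style loop nest; this mirrors the unit tests bundled with Carver:

```python
from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch

arch = auto_infer_current_arch()

# for i in Spatial(1024):
#     for j in Spatial(1024):
#         ...
carve_template = carver.GeneralReductionTemplate(
    structure="SS",
    shape=[1024, 1024],
    dtype="float16",
).with_arch(arch)

hints = carve_template.recommend_hints(topk=20)
assert len(hints) > 0
```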
We can specialize more advanced templates to provide finer-grained information, such as `MatmulTemplate`.
### Matmul Template
Carver also provides a specialized `MatmulTemplate` for matrix multiplication (e.g., `C = A * B`), automatically inferring common tiling strategies (thread blocks, warps, use of tensor cores, etc.).
```python
from tilelang import carver
from tilelang.carver.arch import CUDA
arch = CUDA("nvidia/geforce-rtx-4090")
carve_template = carver.MatmulTemplate(
M=1024,
N=1024,
K=1024,
in_dtype="float16",
accum_dtype="float16",
out_dtype="float16",
).with_arch(arch)
# Retrieve the (symbolic) function describing the matmul
func = carve_template.equivalent_function()
print("Equivalent Function:\n", func)
# Generate hints
hints = carve_template.recommend_hints(topk=20)
for hint in hints:
print(hint)
```
**Example Output**:
```python
{
'block': [32, 64],
'warp': [16, 32],
'rstep': [128],
'use_tc': True,
...
},
{
'block': [64, 32],
'warp': [32, 16],
'rstep': [128],
'use_tc': True,
...
},
...
{
'block': [256, 32],
'warp': [128, 16],
'rstep': [32],
'use_tc': True,
...
}
```
---
## Supported Architectures
Carver currently provides out-of-the-box support for:
- **CUDA**: e.g., `arch = CUDA("nvidia/geforce-rtx-4090")`
- **CDNA** (AMD GPU-like backends)
- **CPU**
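A device description can either be constructed directly from a TVM target string (as in the examples above) or obtained through the helpers in `tilelang.carver.arch`. A minimal sketch, assuming a CUDA-capable machine:

```python
from tilelang.carver.arch import CUDA, get_arch, auto_infer_current_arch

# Construct from an explicit TVM target string...
arch = CUDA("nvidia/geforce-rtx-4090")

# ...or dispatch on the target kind ("cuda" -> CUDA, "llvm" -> CPU, "hip" -> CDNA)...
arch = get_arch("cuda")

# ...or let Carver infer the current device (currently this defaults to CUDA).
arch = auto_infer_current_arch()
```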
Adding a new architecture is as simple as implementing a new subclass of `TileDevice` or providing a custom target that describes:
- Shared/local memory capacity
- Warp (or vector) size
- Cache sizes
- Tensor instructions available
Below is an **illustrative snippet** of the CUDA backend:
```python
class CUDA(TileDevice):
def __init__(self, target: Union[tvm.target.Target, str]):
...
self.platform = "CUDA"
# Device constraints
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
...
self.transaction_size = [32, 128] # bytes
self.bandwidth = [750, 12080] # MB/s, approximate
self.available_tensor_instructions = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
TensorInstruction("mma", [16, 16]),
TensorInstruction("wmma", [16, 16]),
)
return [t.shape for t in self.available_tensor_instructions]
def __repr__(self):
return f"CUDA({self.target})"
```
## Adapting Hints to Other Compilers
One of Carver’s main benefits is its adaptability. Here is an example of adapting a hint to Triton:
Given a Carver hint like:
```python
{
'block': [32, 64],
'warp': [16, 32],
'rstep': [128],
'use_tc': True,
'vectorize': {'A_reindex': 8, 'B_reindex': 8}
}
```
You might interpret this in **Triton** as:
- `block_m = 32, block_n = 64, block_k = 128`
- Potential warp usage = `warp_m = 16, warp_n = 32`
- `vectorize`: load data with a vector width of 8
- If `use_tc` is true, consider using Tensor Cores (TensorOps in Triton) if supported.
This makes it easy to test multiple configurations without guessing them by hand.
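For example, here is a minimal, hypothetical helper (the name `hint_to_triton_config` and the `BLOCK_M`/`BLOCK_N`/`BLOCK_K` meta-parameter names are our own choices, not part of Carver or Triton) that turns such a hint into a `triton.Config` for Triton's autotuner:

```python
import triton


def hint_to_triton_config(hint: dict, num_stages: int = 2) -> triton.Config:
    """Map a Carver hint onto a Triton autotuner Config (illustrative sketch)."""
    block_m, block_n = hint["block"]
    block_k = hint["rstep"][0]
    # Each warp covers a `warp`-shaped sub-tile of the block tile.
    warp_m, warp_n = hint.get("warp", hint["block"])
    num_warps = max(1, (block_m // warp_m) * (block_n // warp_n))
    return triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
        num_warps=num_warps,
        num_stages=num_stages,
    )


cfg = hint_to_triton_config({
    "block": [32, 64],
    "warp": [16, 32],
    "rstep": [128],
    "use_tc": True,
    "vectorize": {"A_reindex": 8, "B_reindex": 8},
})
print(cfg)  # e.g. BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 128, num_warps: 4, num_stages: 2, ...
```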
## Supported Templates
Carver abstracts common loop patterns through templates:
- **`GeneralReductionTemplate`**: For general `Spatial-Spatial-Reduce` (SSR) structures or similar.
- **`MatmulTemplate`**: For standard matrix multiplication `C = A * B`.
- **`GEMVTemplate`**: For `y = Ax` or `y = xA` style operations.
- **`ElementwiseTemplate`**: For elementwise transformations or pointwise ops.
You can also create your own specialized templates if you have unique loop structures or constraints. For instance, you might define specialized templates for convolution, flash attention, etc.
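Of the built-in templates above, `GEMVTemplate` is the only one without an example earlier in this README; a minimal sketch mirroring the bundled unit tests:

```python
from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch

arch = auto_infer_current_arch()

carve_template = carver.GEMVTemplate(
    N=1024,
    K=1024,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
).with_arch(arch)

func = carve_template.equivalent_function()
print("Equivalent Function:\n", func)

hints = carve_template.recommend_hints(topk=20)
for hint in hints:
    print(hint)
```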
## TODO Items
- [ ] **Flash Attention** and its variants: Support search-space generation for specialized attention kernels.
- [ ] **Adapt to tile language**: Provide ready-made scheduling calls or wrappers for [tilelang](https://github.com/LeiYanggh/tilelang) to streamline end-to-end integration.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Base infra"""
from .analysis import (
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
collect_block_iter_vars_used_in_access_region, # noqa: F401
collect_vars_used_in_prim_expr, # noqa: F401
detect_dominant_read, # noqa: F401
is_broadcast_epilogue, # noqa: F401
normalize_prim_func, # noqa: F401
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Set, Union
from typing_extensions import Literal
from tvm import ir, tir, DataType
from tvm._ffi import get_global_func
from tvm.target.target import Target
from tvm.tir import Schedule, IterVar
from tvm.tir.schedule import BlockRV
class IterInfo:
"""Information about a loop/iter var."""
kind: Literal["S", "R", "O"]
var: tir.Var
_dom: tir.PrimExpr
loop_rv: tir.schedule.LoopRV
def __init__(
self,
kind: Literal["S", "R", "O"],
var: tir.Var,
dom: tir.PrimExpr,
loop_rv: tir.schedule.LoopRV,
):
"""Construct an IterInfo object."""
self.kind = kind
self.var = var
self._dom = dom
self.loop_rv = loop_rv
@property
def dom(self) -> Union[int, tir.PrimExpr]:
"""The iteration domain of the loop."""
return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
def __str__(self) -> str:
return f'Iter("{self.kind}", {self.dom})'
def __repr__(self) -> str:
return str(self)
class BlockInfo:
"""Information about a TIR block."""
name: str
iters: List[IterInfo]
block_rv: tir.schedule.BlockRV
_reduction_block: bool
def __init__(
self,
name: str,
iters: List[IterInfo],
block_rv: tir.schedule.BlockRV,
reduction_block: bool = False,
):
"""Construct a BlockInfo object."""
self.name = name
self.block_rv = block_rv
self.iters = iters
self._reduction_block = reduction_block
def dom(self) -> List[Union[int, tir.PrimExpr]]:
"""The iteration domain of the block."""
return [i.dom for i in self.iters]
def dom_kind(self) -> str:
"""The iteration domain kind of the block, for example, SSSS, SSSR."""
return "".join(i.kind for i in self.iters)
def is_injective(self) -> bool:
"""Whether the block is injective, i.e. all its iteration domains are injective."""
return all(k == "S" for k in self.dom_kind())
def is_elementwise(self, sch: tir.Schedule) -> bool:
"""Whether the block is elementwise, i.e. trivial mapping between read/write region"""
def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
return dom.min.same_as(var) and dom.extent == 1
if not self.is_injective():
return False
block = sch.get(self.block_rv)
if len(block.reads) != 1 or len(block.writes) != 1:
return False
r_region = block.reads[0].region
w_region = block.writes[0].region
if len(r_region) != len(w_region):
return False
for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region):
            # Pass the Range first and the underlying loop Var second, matching _check_unit_var_range's signature.
            if not _check_unit_var_range(r_dom, var.var) or not _check_unit_var_range(w_dom, var.var):
return False
return True
def is_reduction(self) -> bool:
"""Whether the block is a reduction workload."""
# TODO(@junrushao): distinguish GEMV and reduction
return self._reduction_block
def is_gemv(self) -> bool:
"""Whether the block is a GEMV workload."""
raise NotImplementedError
def is_gemm(self) -> bool:
"""Whether the block is a GEMM workload."""
raise NotImplementedError
def __str__(self) -> str:
return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})'
def __repr__(self) -> str:
return str(self)
_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
"""Normalize the primfunc to normal form"""
try:
result = _normalize_prim_func(sch)
if result is None:
return None
except Exception: # pylint: disable=broad-except
return None
def _iter_kind(i: tir.IterVar) -> str:
return {
tir.IterVar.DataPar: "S",
tir.IterVar.CommReduce: "R",
}.get(i.iter_type, "O")
blocks: List[BlockInfo] = []
for block, loops, iters, is_reduction in zip(*result):
blocks.append(
BlockInfo(
name=sch.get(block).name_hint,
iters=[
IterInfo(
kind=_iter_kind(iter), # type: ignore
var=iter.var,
dom=iter.dom,
loop_rv=loop,
) for loop, iter in zip(loops, iters)
],
block_rv=block,
reduction_block=is_reduction,
))
return blocks
def find_var_from_func(func, var: str):
for buffer in func.buffer_map.values():
for i in buffer.shape:
if isinstance(i, tir.Var) and i.name == var:
return i
return None
def check_func_with_dynamic(func):
for buffer in func.buffer_map.values():
for i in buffer.shape:
if isinstance(i, tir.Var):
return True
return False
def _assert_gpu_target(target: Target):
if "gpu" not in target.keys:
raise ValueError(f"Expect a GPU target, but got {target}")
def get_max_threads_per_block(target: Target) -> int:
_assert_gpu_target(target)
max_threads_per_block = None
for name in ["max_threads_per_block", "max_num_threads"]:
if max_threads_per_block is None:
max_threads_per_block = target.attrs.get(name, None)
if max_threads_per_block is None:
max_threads_per_block = 64
return int(max_threads_per_block)
def get_max_shared_memory_per_block(target: Target) -> int:
_assert_gpu_target(target)
max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
if max_shared_memory_per_block is None:
raise ValueError(
f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually")
return int(max_shared_memory_per_block)
def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
try:
block = sch.mod[func_name].body.block
except Exception:
raise ValueError(f"The function body is expected to be the root block, but got:\n"
f"{sch.mod[func_name].body}") from None
return sch.get_block(block.name_hint)
def collect_block_iter_vars_used_in_access_region(block: tir.Block,
region: List[ir.Range]) -> Set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set()
for expr in region:
if expr.extent != 1:
continue
tir_vars |= collect_vars_used_in_prim_expr(expr.min)
tir_vars &= set(iter_var.var for iter_var in block.iter_vars)
return tir_vars
def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
"""Collect the variables used in the PrimExpr."""
tir_vars = set()
def _collect_tir_var(expr):
if isinstance(expr, tir.Var):
tir_vars.add(expr)
tir.stmt_functor.post_order_visit(expr, _collect_tir_var)
return tir_vars
def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
"""Detect the dominant read indices in the block."""
dominant_read = None
num_read_iters = -1
for buffer_region in block.reads:
tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region)
if num_read_iters < len(tir_vars):
num_read_iters = len(tir_vars)
dominant_read = buffer_region
assert dominant_read is not None
(result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region])
return result
def is_broadcast_epilogue(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
epilogue: tir.schedule.BlockRV,
) -> bool:
"""Check if the epilogue block is a broadcast pattern"""
write_buffers = {r.buffer for r in sch.get(block).writes}
epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1}
for buffer_region in sch.get(epilogue).reads:
if buffer_region.buffer not in write_buffers:
continue
tir_vars = collect_block_iter_vars_used_in_access_region(
sch.get(epilogue), buffer_region.region)
if len(tir_vars) < len(epilogue_iters):
return True
return False
def get_reduction_blocks(sch: tir.Schedule,
blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block)
iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
return iter_types == {IterVar.CommReduce, IterVar.DataPar}
def is_spatial(block: BlockRV) -> bool:
block_stmt = sch.get(block)
iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
return iter_types == {IterVar.DataPar}
# NOTE: We assume there is only one reduction block in the function
# all blocks are required to be spatial or reduction
if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
return None
# There is only one reduction block
reduction_blocks = [block for block in blocks if is_reduction(block)]
if len(reduction_blocks) == 0:
return None
return reduction_blocks
def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
    # GPU memory prefers 128-bit (16-byte) coalesced accesses, e.g. four 32-bit banks.
buffers: List[tir.Buffer] = []
for read in block_stmt.reads:
buffers.append(read.buffer)
for write in block_stmt.writes:
buffers.append(write.buffer)
# pick the dtype with the largest bits
max_dtype_bits: int = 0
for buffer in buffers:
max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits)
return target_bits // max_dtype_bits
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .arch_base import TileDevice
from .cuda import CUDA
from .cpu import CPU
from .cdna import CDNA
from typing import Union
from tvm.target import Target
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
if isinstance(target, str):
target = Target(target)
if target.kind.name == "cuda":
return CUDA(target)
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
return CDNA(target)
else:
raise ValueError(f"Unsupported target: {target.kind.name}")
def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
return get_arch("cuda")
from .cpu import is_cpu_arch # noqa: F401
from .cuda import (
is_cuda_arch, # noqa: F401
is_volta_arch, # noqa: F401
is_ampere_arch, # noqa: F401
is_ada_arch, # noqa: F401
is_hopper_arch, # noqa: F401
is_tensorcore_supported_precision, # noqa: F401
has_mma_support, # noqa: F401
)
from .cdna import is_cdna_arch # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
class TileDevice:
"""
Represents the architecture of a computing device, capturing various hardware specifications.
"""
def __init__(self) -> None:
self.reg_cap: int = 0 # Register capacity: The amount of register memory available
self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available
self.compute_max_core: int = 0 # The maximum number of computing cores
self.warp_size: int = (
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
self.transaction_size: List[int] = [
0,
0,
] # The size of memory transactions, typically in bytes
self.max_smem_usage: int = 0 # The maximum shared memory usage allowed
self.bandwidth: List[int] = [
0,
0,
] # Bandwidth specifications, possibly including peak and sustained rates
self.platform: str = "unknown" # The platform or manufacturer of the device
self.compute_capability: str = (
"unknown" # The compute capability, indicating the feature set and performance level
)
self.l2_cache_size_bytes: int = 0
# the number of transaction size in bytes
self.transaction_size: List[int] = [0, 0] # in bytes
# bandwidth in MB/s, will be used for recommend basic tile size
self.bandwidth: List[int] = [0, 0]
def get_avaliable_tensorintrin_shapes(self):
raise NotImplementedError()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
def is_cdna_arch(arch: TileDevice) -> bool:
return isinstance(arch, CDNA)
class CDNA(TileDevice):
def __init__(self, target: Union[Target, str]):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
device = tvm.runtime.rocm(0)
if not device.exist:
raise RuntimeError("Cannot find HIP device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CDNA"
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
self.compute_capability = device.compute_version.replace(".", "")
self.reg_cap: int = 32768
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
self.transaction_size: List[int] = [32, 128] # in bytes
self.bandwidth: List[int] = [1300, 14000]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
def is_cpu_arch(arch: TileDevice) -> bool:
return isinstance(arch, CPU)
# For the LLVM backend, we do not provide detailed CPU information,
# as the LLVM backend does not require tuning; this class only keeps the interface consistent.
class CPU(TileDevice):
def __init__(self, target: Target):
self.target = target
device = tvm.runtime.cpu(0)
if not device.exist:
raise RuntimeError("Cannot find cpu device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CPU"
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
def is_cuda_arch(arch: TileDevice) -> bool:
return isinstance(arch, CUDA)
def is_volta_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 70)
conditions.append(arch.sm_version < 80)
return all(conditions)
def is_ampere_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80 and arch.sm_version < 89)
return all(conditions)
def is_ada_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 89)
return all(conditions)
def is_hopper_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 90)
return all(conditions)
def has_mma_support(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80)
return all(conditions)
volta_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
]
ampere_tensorcore_supported = [
("bfloat16", "float32"),
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("int4", "int32"),
("int2", "int32"),
("int1", "int32"),
]
ada_tensorcore_supported = [
("bfloat16", "float32"),
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("e5m2_float8", "float32"),
("e4m3_float8", "float32"),
]
hopper_tensorcore_supported = ada_tensorcore_supported
# TODO(lei): we should consider the dtype of the input a and b
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports e4m3_float8 * e5m2_float8
def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_volta_arch(arch):
return (in_dtype, accum_dtype) in volta_tensorcore_supported
elif is_ampere_arch(arch):
return (in_dtype, accum_dtype) in ampere_tensorcore_supported
elif is_ada_arch(arch):
return (in_dtype, accum_dtype) in ada_tensorcore_supported
elif is_hopper_arch(arch):
return (in_dtype, accum_dtype) in hopper_tensorcore_supported
else:
raise ValueError(f"Unsupported architecture: {arch}")
class TensorInstruction(object):
def __init__(
self,
name: str,
shape: List[int],
):
self.name: str = name
# only hold the shape of M and N
self.shape: List[int] = shape
class CUDA(TileDevice):
def __init__(self, target: Union[Target, str]):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
self.sm_version = check_sm_version(self.target.arch)
device = tvm.runtime.cuda(0)
if not device.exist:
raise RuntimeError("Cannot find cuda device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "CUDA"
self.smem_cap = device.max_shared_memory_per_block
self.compute_max_core = device.multi_processor_count
self.warp_size = device.warp_size
self.compute_capability = device.compute_version.replace(".", "")
self.reg_cap: int = 65536
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
# the number of transaction size in bytes
self.transaction_size: List[int] = [32, 128] # in bytes
# bandwidth in MB/s, will be used for recommend basic tile size
        # TODO(lei): find a way to query the real bandwidth.
        # However, the bandwidth ratios between memory levels are similar across
        # devices, so these values work reasonably well for other devices too.
self.bandwidth: List[int] = [750, 12080]
# get the available tensor instructions during runtime to avoid
# the dependency of the tensor intrinsics registration
self.available_tensor_instructions: List[TensorInstruction] = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
TensorInstruction("mma", [16, 16]),
TensorInstruction("wmma", [16, 16]),
)
return [t.shape for t in self.available_tensor_instructions]
def __repr__(self):
return f"CUDA({self.target})"
# Copyright 2018 The apache/tvm Authors. All Rights Reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
from typing import Callable, List
from tvm import tir
from .utils import retrieve_func_from_module
from .analysis import BlockInfo
def get_block(
sch: tir.Schedule,
blocks: List[BlockInfo],
name: str,
):
"""Get the target block from a schedule.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to get target block.
name : str
The name of the target block.
Returns
-------
target_block : BlockRV
The target block.
"""
target_block: tir.BlockRV = None
for block_info in blocks:
block = block_info.block_rv
if sch.get(block).name_hint == name:
target_block = block
return target_block
def get_output_blocks(
sch: tir.Schedule,
blocks: List[BlockInfo],
):
"""Get the output blocks of a schedule.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to get output blocks.
blocks : List[BlockInfo]
The blocks to be analyzed.
Returns
-------
output_blocks : List[BlockInfo]
The output blocks.
"""
# collect arguments buffer
func = retrieve_func_from_module(sch.mod)
args = list(func.buffer_map.values())
output_blocks = []
for block_info in blocks:
block = block_info.block_rv
for write in sch.get(block).writes:
if write.buffer in args:
output_blocks.append(block)
return output_blocks
def try_inline(
sch: tir.Schedule,
blocks: List[BlockInfo],
) -> List[BlockInfo]:
"""Try to inline as many blocks as possible, and return the remaining blocks.
Parameters
----------
sch : tir.Schedule
The TIR schedule used to inline blocks.
blocks : List[BlockInfo]
The blocks to be inlined.
Returns
-------
remaining : List[BlockInfo]
The remaining blocks that cannot be inlined.
"""
def _trial(func: Callable):
for i, block in enumerate(blocks):
try:
func(block.block_rv)
except Exception: # pylint: disable=bare-except
continue
return i
return None
while True:
i = _trial(sch.compute_inline)
if i is None:
i = _trial(sch.reverse_compute_inline)
if i is None:
break
blocks.pop(i)
return blocks
def try_inline_contiguous_spatial(
sch: tir.Schedule,
block_infos: List[BlockInfo],
) -> List[BlockInfo]:
"""Try to inline contiguous spatial blocks in a schedule
Parameters
----------
sch : tir.Schedule
The TIR schedule used to inline blocks.
block_infos : List[BlockInfo]
The blocks to be try.
Returns
-------
remaining : List[BlockInfo]
The remaining blocks that cannot be inlined.
"""
if block_infos is None:
return None
results = []
spatial_blocks = []
block: BlockInfo
for block in block_infos:
if block.is_injective():
spatial_blocks.append(block)
elif spatial_blocks:
results.extend(try_inline(sch, spatial_blocks))
results.append(block)
spatial_blocks = []
else:
results.append(block)
if spatial_blocks:
results.extend(try_inline(sch, spatial_blocks))
return results
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401
from ..arch import TileDevice, CUDA # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Benefit For Carver Schedule"""
class Block:
def __init__(self, start, end, is_free):
self.start = start
self.end = end
self.is_free = is_free
def size(self) -> int:
return self.end - self.start
def merge(self, other):
assert self.is_free == other.is_free
self.start = min(self.start, other.start)
self.end = max(self.end, other.end)
def __repr__(self) -> str:
return "<Block offset={} size={}>".format(self.start, self.size())
class BestFit:
def __init__(self, align=32):
self.limit = 0
self.list = []
self.align = align
def malloc(self, size) -> Block:
size = (size + self.align - 1) // self.align * self.align
found = None
for block in self.list:
            # Best fit: among free blocks that are large enough, prefer the smallest one.
            if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
found = block
if found:
found.is_free = False
remain = found.size() - size
if remain != 0:
found.end -= remain
self.list.insert(
self.list.index(found) + 1, Block(found.end, found.end + remain, True))
return found
elif len(self.list) > 0 and self.list[-1].is_free:
add = size - self.list[-1].size()
self.list[-1].end += add
self.limit = self.list[-1].end
self.list[-1].is_free = False
return self.list[-1]
else:
block = Block(self.limit, self.limit + size, False)
self.list.append(block)
self.limit += size
return block
def free(self, block: Block) -> None:
assert not block.is_free
idx = self.list.index(block)
self.list[idx] = Block(block.start, block.end, True)
if idx + 1 < len(self.list) and self.list[idx + 1].is_free:
self.list[idx].merge(self.list[idx + 1])
self.list.pop(idx + 1)
if idx - 1 >= 0 and self.list[idx - 1].is_free:
self.list[idx].merge(self.list[idx - 1])
self.list.pop(idx - 1)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Hint definition for schedule"""
from tvm import DataType
from typing import Dict, List, Tuple
from . import PrimFuncNode
import numpy as np
from .rasterization import *
class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
"""
def __init__(
self,
AS_shape: Tuple[int],
BS_shape: Tuple[int],
AF_shape: Tuple[int],
BF_shape: Tuple[int],
tc_axis: Tuple[int],
) -> None:
self.AS_shape: Tuple[int] = AS_shape
self.BS_shape: Tuple[int] = BS_shape
self.AF_shape: Tuple[int] = AF_shape
self.BF_shape: Tuple[int] = BF_shape
self.tc_axis: Tuple[int] = tc_axis
class Stride:
"""
Manages stride information for a given axis of a tensor.
"""
def __init__(self, stride: int = 1, ax: int = -1) -> None:
# which axis to put stride on
self._ax: int = int(ax)
# the stride size of the axis
self._stride: int = int(stride)
@property
def ax(self) -> int:
return self._ax
@property
def stride(self) -> int:
return self._stride
def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
ndim = len(shape)
strides = [1 for _ in shape]
for i in range(ndim - 2, -1, -1):
if i == self.ax:
strides[i] = self.stride
else:
strides[i] = int(strides[i + 1] * shape[i + 1])
return strides
def compute_elements_from_shape(self, shape: List[int]) -> int:
original_shape = np.prod(shape)
if not self.is_valid():
strided_elem = original_shape
else:
assert self.ax < len(shape)
strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride
assert strided_elem >= original_shape
return int(strided_elem)
def is_valid(self) -> bool:
return self.ax >= 0
def __repr__(self) -> str:
return f"<Stride, {self._ax}, {self._stride}>"
class TileDict:
"""
Manages tiling information and configurations for computational tasks.
"""
def __init__(self, output_tile) -> None:
self.output_tile = output_tile
# schedule config
self.tile_map = {}
self.rstep_map = {}
self.cached_tensors_map = {}
self.output_strides_map = {}
self.tensor_strides_map = {}
# analysis
self.traffic = -1
self.smem_cost = -1
self.block_per_SM = -1
self.num_wave = -1
self.grid_size = -1
self.valid = True
def get_tile(self, func) -> List[int]:
return self.tile_map[func]
def get_rstep(self, func) -> Dict[str, int]:
return self.rstep_map
def __hash__(self) -> int:
return hash(tuple(self.output_tile))
class IntrinInfo:
"""
    Information describing the tensor core intrinsic configuration.
"""
def __init__(
self,
in_dtype: str,
out_dtype: str,
trans_b: bool,
input_transform_kind: int = 0,
weight_transform_kind: int = 0,
) -> None:
self.in_dtype = in_dtype
self.out_dtype = out_dtype
self.trans_a = False
self.trans_b = trans_b
self.input_transform_kind = input_transform_kind
self.weight_transform_kind = weight_transform_kind
def __repr__(self) -> str:
return f"<IntrinInfo, {self.in_dtype}, {self.out_dtype}, {self.trans_b}, {self.propagate_b}>"
def is_input_8bit(self) -> bool:
return DataType(self.in_dtype).bits == 8
@property
def smooth_a(self) -> bool:
return self.input_transform_kind >= 2
@property
def smooth_b(self) -> bool:
return self.weight_transform_kind >= 2
@property
def inter_transform_a(self) -> bool:
return self.input_transform_kind >= 1
@property
def inter_transform_b(self) -> bool:
return self.weight_transform_kind >= 1
class Hint(object):
"""
Central configuration class for managing various parameters of computational tasks.
"""
def __init__(self) -> None:
self.arch = None
self.use_tc = None # todo(lei): this should be renamed.
# Special axes tiling info
self.block = []
self.thread = []
# Special axes for MFMA
self.warp = []
# Reduce axes tiling info
self.rstep = []
self.reduce_thread = []
self.rasterization_plan = NoRasterization()
self.cached_tensors = []
self.output_strides = {}
self.schedule_stages = None
# Config for block reduction
self.block_reduction_depth = None # type: int
# TL Specific
# Split-K factor for SM waste optimization
self.split_k_factor: int = 1
# Experimental
self._raxis_order = []
self._step = []
self.vectorize: Dict[str, int] = {}
self.pipeline_stage = 1
self.use_async = False
self.opt_shapes: Dict[str, int] = {}
self.intrin_info = IntrinInfo("float16", "float16", True)
self.shared_scope: str = "shared"
self.pass_context: Dict = {}
def to_dict(self) -> Dict:
dic = {}
dic["block"] = self.block
if self.use_tc:
dic["warp"] = self.warp
else:
dic["thread"] = self.thread
dic["rstep"] = self.rstep
if np.prod(self.reduce_thread) > 1:
dic["reduce_thread"] = self.reduce_thread
if self.use_tc:
dic["use_tc"] = self.use_tc
if self.output_strides:
dic["strides"] = {}
for k, stride in self.output_strides.items():
if stride.is_valid():
dic["strides"][k] = stride
if len(dic["strides"]) == 0:
del dic["strides"]
if np.prod(self._step) > 1:
dic["step"] = self._step
if self._raxis_order != []:
dic["raxis_order"] = self._raxis_order
if self.vectorize != {}:
dic["vectorize"] = self.vectorize
if self.pipeline_stage != 1:
dic["pipeline_stage"] = self.pipeline_stage
if self.block_reduction_depth is not None:
dic["block_reduction_depth"] = self.block_reduction_depth
return dic
@classmethod
def from_dict(cls, dic: Dict) -> "Hint":
hint = cls()
for k, v in dic.items():
setattr(hint, k, v)
return hint
def tensorcore_legalization(self):
# only keep the last 2 axes for tensorcore
self.warp = self.warp[-2:]
self.block = self.block[-2:]
return self
@property
def raxis_order(self) -> List[int]:
if self._raxis_order != []:
return self._raxis_order
return list(range(len(self.rstep)))
@property
def step(self) -> List[int]:
if self._step != []:
return self._step
return [1 for _ in self.block]
def __repr__(self) -> str:
return str(self.to_dict())
def complete_config(self, node: PrimFuncNode):
# analysis pass context, for int8 mma, we should merge static shared memory
merge_static_smem = False
# int32 and float32 accum may take too much shared memory
if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
merge_static_smem = True
# Always merge dynamic shared memory
if self.shared_scope == "shared.dyn":
merge_static_smem = True
self.pass_context = {"tir.merge_static_smem": merge_static_smem}
return self
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""PrimFunc Wrapper and Block information Analaysis"""
import tvm
from tvm import tir
from tvm.tir import IterVar, PrimFunc
from typing import Any, Dict, List, Tuple, Optional
from tvm.tir.schedule.schedule import BlockRV
import numpy as np
import functools
from ..analysis import BlockInfo, get_reduction_blocks
from .. import analysis
from .. import normalize_prim_func
from .shape_inference import get_analyzer_by_tir
def pre_order_traverse(block_analyzer, blocks, func):
visited = set()
def _traverse(block):
if block in visited:
return
visited.add(block)
for dep_blocks in block_analyzer.get_consumer_blocks(block):
_traverse(dep_blocks)
func(block)
for block in blocks:
_traverse(block)
class BlockAnalyzer(object):
def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch
self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch)
def get_block_name(self, block: BlockRV) -> str:
return self.sch.get(block).name_hint
def get_block_info(self, block: BlockRV) -> BlockInfo:
for block_info in self.block_infos:
if self.get_block_name(block) == block_info.name:
return block_info
return None
def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
block_info = self.get_block_info(block)
axis = []
for iter in block_info.iters:
if iter.kind == "S":
axis.append(iter)
return axis
def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
block_info = self.get_block_info(block)
raxis = []
for iter in block_info.iters:
if iter.kind == "R":
raxis.append(iter)
return raxis
def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]:
buffers = []
for read in self.sch.get(block).reads:
buffers.append(read.buffer)
return buffers
def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]:
buffers = []
for write in self.sch.get(block).writes:
buffers.append(write.buffer)
return buffers
def get_buffers(self, block: BlockRV) -> List[tir.Buffer]:
return self.get_input_buffers(block) + self.get_output_buffers(block)
def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]:
return self.sch.get_producers(block)
def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]:
return self.sch.get_consumers(block)
class Node(object):
def __init__(self, tags: Optional[Dict] = None) -> None:
if tags is None:
tags = {}
self._dtypes = []
self._tag: Dict = {}
for tag in tags:
self.add_tag(tag, tags[tag])
def set_tag(self, k: str, v: Any = True) -> None:
self.add_tag(k, v)
def add_tag(self, k: str, v: Any = True) -> None:
self._tag[k] = v
def get_tag(self, k: str) -> Any:
if k not in self._tag:
return None
return self._tag[k]
class PrimFuncNode(Node):
def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None:
super().__init__(tags)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
self.schedule_stages: List[BlockRV] = []
self.blocks: List[BlockRV] = []
self.output_blocks: List[BlockRV] = None
self.reduction_block: BlockRV = None
self.raxis = []
self.input_buffers = []
self.output_buffers = []
self.buffers = []
self.args = []
self._analysis_funcinfo()
self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks)
def _specialize_func(self, func: PrimFunc):
# Specialize the function to make it more friendly for analysis.
# set attrs
for k, v in func.attrs.items():
self.set_tag(k, v)
if self.get_tag("is_speclized"):
return func
opt_shapes = self.get_tag("opt_shapes")
if opt_shapes:
for name, shape in opt_shapes.items():
var = analysis.find_var_from_func(func, name)
if var is not None:
func = func.specialize({var: shape.astype(var.dtype)})
return func
def _analysis_funcinfo(self):
root_block = analysis.get_root_block(self.sch)
blocks = self.sch.get_child_blocks(root_block)
self.blocks = blocks
self.output_blocks = self.sch.get_output_blocks(root_block)
reduction_blocks = get_reduction_blocks(self.sch, blocks)
if reduction_blocks is None:
self.reduction_block = None
self.schedule_stages.append(*self.output_blocks)
else:
# analysis on the last reduction block
self.reduction_block = reduction_blocks[-1]
# set raxis
reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block)
for iter in reduce_block_info.iters:
if iter.kind == "R":
self.raxis.append(iter)
self.schedule_stages.append(self.reduction_block)
# collect output buffers
for output_block in self.output_blocks:
for write in self.sch.get(output_block).writes:
if write not in self.output_buffers:
self.output_buffers.append(write.buffer)
for param in self.prim_func.params:
if param not in self.prim_func.buffer_map:
                # dynamic symbolic vars may appear in params without a buffer binding
continue
buffer = self.prim_func.buffer_map[param]
if buffer not in self.output_buffers:
self.input_buffers.append(buffer)
self.args = self.input_buffers + self.output_buffers
self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()]
# set dtype
self.set_dtype(tvm.DataType(self.output_buffers[0].dtype))
def get_opt_shape(self, name) -> int:
opt_shapes = self.get_tag("opt_shapes")
if opt_shapes is None:
return None
return opt_shapes[name]
def extent_wrapper(self, value) -> int:
if isinstance(value, tvm.tir.Var):
return self.get_opt_shape(value.name)
elif isinstance(value, tvm.tir.IntImm):
return int(value)
else:
return value
@functools.lru_cache()
def get_space_dim(self) -> List[int]:
dim_size = []
if self.reduction_block:
block_info = self.block_analyzer.get_block_info(self.reduction_block)
for iter in block_info.iters:
if iter.kind == "S":
if isinstance(iter.dom.extent, tvm.tir.IntImm):
dim_size.append(int(iter.dom.extent))
else:
assert isinstance(iter.dom.extent, tvm.tir.Var)
dim_size.append(self.get_opt_shape(iter.dom.extent.name))
else:
# assume outer stage has the same shape
loops = self.sch.get_loops(self.schedule_stages[0])
for loop in loops:
dim_size.append(int(self.sch.get(loop).extent))
return [int(x) for x in dim_size]
def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
assert isinstance(dtype, tvm.DataType), type(dtype)
if dtype == tvm.DataType("bool"):
dtype = tvm.DataType("int8")
if len(self._dtypes) <= id:
self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)])
elif self._dtypes[id] is not None:
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
def get_dtype(self, id=0) -> tvm.DataType:
return self._dtypes[id]
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)
def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
if rstep is None:
rstep = {}
shape = {
self.block_analyzer.get_output_buffers(block)[0].name: [
tvm.arith.ConstIntBound(0, val - 1) for val in tile
] for block in self.schedule_stages
}
return self.ana.infer(shape, rstep, targets)
def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
targets = [t.name for t in self.args[:read_idx_offset]]
shapes, intermediate_bind = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(self.args[:read_idx_offset]):
if arg.name in intermediate_bind:
results.append(shapes[arg.name])
continue
# should not exceed original shape
trimmed_shape = [
self.extent_wrapper(i)
for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
]
results.append(trimmed_shape)
return results
# Propagate inputs only on reduction block
def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
reduction_block = self.reduction_block
args = self.block_analyzer.get_input_buffers(reduction_block)
targets = [t.name for t in args]
shapes, intermediate_bind = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(args):
if arg.name in intermediate_bind:
results.append(shapes[arg.name])
continue
# should not exceed original shape
propagate_shape = shapes[arg.name]
buffer_shape = args[i].shape
if len(buffer_shape) > len(propagate_shape):
buffer_shape = buffer_shape[-len(propagate_shape):]
trimmed_shape = [
self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))
]
results.append(trimmed_shape)
return results
def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
targets = [t.name for t in self.args[read_idx_offset:]]
shapes, _ = self.propagate(tile, rstep, targets)
results = []
for i, arg in enumerate(self.args[read_idx_offset:]):
# should not exceed original shape
trimmed_shape = list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
results.append(trimmed_shape)
return results
def propagate_reduction_inputs(self,
shape,
rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
if rstep is None:
rstep = {}
if self.reduction_block is None:
return {}
targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)]
results, _ = self.propagate(shape, rstep, targets)
return results
def get_reduce_inputs_dtype(self):
if self.reduction_block is None:
return {}
return {
b.name: tvm.DataType(b.dtype)
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
@functools.lru_cache()
def infer_tensorcore_axis(self) -> Tuple[int]:
        # The tensor core axis mapping is fixed for a given expression, so it is inferred once and cached.
assert self.get_tag("tensorcore_config")
C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok
output_buffer_shape = (
self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
valid_region = []
for region in output_buffer_shape:
if region.value == 1:
continue
valid_region.append(region)
num_nvalid_regions = len(output_buffer_shape) - len(valid_region)
self.set_tag("num_nvalid_regions", num_nvalid_regions)
def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
spatial_dim = self.get_space_dim()
assert len(valid_region) == len(
spatial_dim), f" {valid_region} mismatch with {spatial_dim}"
cl_shapes = [1] * len(spatial_dim)
cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n
return cl_shapes
CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions)
self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]])
shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis})
A_deps, B_deps = shapes.values()
A_ax_m = A_deps.index(wmma_m)
B_ax_n = B_deps.index(wmma_n)
CL_shape = [1] * len(self.get_space_dim())
shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis})
A_deps, B_deps = shapes.values()
A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k)
B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k)
tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
return tc_axis
def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int:
if stride_map is None:
stride_map = {}
result = 0
shapes, _ = self.propagate(shape, rstep)
def is_broadcast_pattern(buffer, output_buffer):
return (buffer in self.args and
len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and
np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]))
def is_after_reduce_stage(block):
if not self.reduction_block:
return False
reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None)
if reduce_dependent_blocks is None:
reduce_dependent_blocks = set()
pre_order_traverse(
self.block_analyzer,
[self.reduction_block],
lambda block: reduce_dependent_blocks.add(block),
)
self.reduce_dependent_blocks = reduce_dependent_blocks
return block not in reduce_dependent_blocks
# compute cached stages
cached_tensor = []
for block in self.blocks:
output_buffer = self.block_analyzer.get_output_buffers(block)[0]
for buffer in self.block_analyzer.get_input_buffers(block):
cache = buffer.name not in cached_tensor and (
is_broadcast_pattern(buffer, output_buffer) or
self.block_analyzer.get_block_info(block).is_reduction)
if not cache:
continue
cached_tensor.append(buffer.name)
if is_after_reduce_stage(block):
continue # cache after reduce op can often reuse buffer in reduce stage
if buffer.name in stride_map:
num_elem = stride_map[buffer.name].compute_elements_from_shape(
shapes[buffer.name])
else:
num_elem = np.prod(shapes[buffer.name])
buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)
buffer_len = (buffer_len + 31) // 32 * 32
result += buffer_len
return result, cached_tensor
def get_input_buffers(self) -> List[tir.Buffer]:
return self.block_analyzer.input_buffers
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .default import DefaultPolicy # noqa: F401
from .tensorcore import TensorCorePolicy # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
import numpy as np
def get_all_factors(n: int) -> List[int]:
# Calculate the square root of n and round it up to the nearest integer
n0 = int(np.ceil(np.sqrt(n)))
# Find all divisors of n that are less than n0
val = np.where(n % np.arange(1, n0) == 0)[0] + 1
# If n is a perfect square, add the square root to the list of factors
mid = np.array([], dtype=int) if n0 * n0 != n else [n0]
# Combine the factors and their corresponding larger pair factors
return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]
def factorize(n: int) -> List[int]:
i = 2 # Start with the smallest prime number
result = []
# Iterate through numbers to find factors
while n > 1:
if n % i == 0: # If i is a factor of n
n //= i # Divide n by i and keep the integer part
result.append(i)
else:
i += 1 # Try the next number
return result
def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
# If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension
if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
return subtensor[-1]
else:
# Recursively calculate the coalesced factor for the remaining dimensions
return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])
def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int:
# Calculate the total number of elements in the subtensor
bytes = int(np.prod(subtensor))
if bytes == 0:
return 0
# Calculate the coalesced factor for the subtensor
factor = int(coalesced_factor(subtensor, tensor))
# Compute the shape of the coalesced tensor
return transaction_size * bytes / min(transaction_size, factor)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for tensorcore schedule"""
import tvm
from typing import Dict, List, Tuple, Optional
import numpy as np
import logging
from ...arch import TileDevice
from ..hint import Hint, Stride, TileDict, IntrinInfo
from ..node import PrimFuncNode
from .common import coalesced_factor, factorize, get_all_factors
from .default import DefaultPolicy
from ..rasterization import NoRasterization, Rasterization2DColumn
logger = logging.getLogger(__name__)
class TensorCorePolicy(DefaultPolicy):
def __init__(self,
func: tvm.tir.PrimFunc,
arch: TileDevice,
tags: Optional[Dict] = None) -> None:
super().__init__(func, arch, tags)
        # Default reduction tile size for the wmma (fp16) intrinsics.
        # However, for int8 mma, wmma_k should be 32.
self.wmma_k = 16
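        # A hedged sketch of dtype-aware selection (not part of the original code;
        # the "in_dtype" tag below is hypothetical):
        #   in_dtype = self.prim_func_node.get_tag("in_dtype")
        #   self.wmma_k = 32 if in_dtype == "int8" else 16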
self.pipeline_stage: int = 1
self.use_async_copy: bool = False
self.block_reduction_depth: Optional[int] = None
self._legalize_info()
def _legalize_info(self):
        pipeline_stage = self.prim_func_node.get_tag("pipeline_stage")
        if pipeline_stage:
            self.pipeline_stage = pipeline_stage
else:
if self.arch.compute_capability == "sm_80":
self.pipeline_stage = 2
else:
self.pipeline_stage = 1
use_async_copy = self.prim_func_node.get_tag("use_async_copy")
if use_async_copy:
self.use_async_copy = use_async_copy
else:
if self.arch.compute_capability == "sm_80":
self.use_async_copy = True
else:
self.use_async_copy = False
        # TODO: block reduction depth is not used for now,
        # as there still exist some performance issues with block reduction.
block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
if block_reduction_depth:
self.block_reduction_depth = block_reduction_depth
def _compute_tc_strides(
self,
node: PrimFuncNode,
tile: List[int],
rstep: Optional[Dict[str, int]] = None,
) -> Tuple[Stride, Stride, Stride]:
if rstep is None:
rstep = {}
        # Strides are used for shared memory padding, which is necessary to avoid
        # shared memory load bank conflicts when the tensorcore layout is not applied.
shapes = node.propagate_reduction_inputs(tile, rstep)
AS_shape, BS_shape = shapes.values()
CS_shape = tile
A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis()
# applying strides
        # TODO(leiwang1999): the offset should be set dynamically; we can use a tag (enable_offset) to control this option.
offset = 8
A_high_ax = min(A_ax_m, A_ax_k)
B_high_ax = min(B_ax_n, B_ax_k)
C_high_ax = min(C_ax_m, C_ax_n)
A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax)
B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax)
C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax)
return A_stride, B_stride, C_stride
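    # Worked example of the padding above (illustrative only): with an fp16 A tile
    # of shape AS_shape = [128, 32] and A_high_ax = 0, the row stride becomes
    # 32 + 8 = 40 elements, so consecutive rows are offset in shared memory,
    # which helps avoid bank conflicts.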
def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):
value, cached_tensors = super().infer_node_smem_usage(td, node)
value *= self.pipeline_stage
return value, cached_tensors
def _assign_reduce_step(self, node):
if not node.get_tag("tensorcore_config"):
return super()._assign_reduce_step(node)
# get reduce input size
target_transaction = self.arch.transaction_size[0] * 2
        # elements per step = (target transaction size in bits) // dtype bits
reduce_input_dtype = node.get_buffer_dtype(
node.block_analyzer.get_input_buffers(node.reduction_block)[0])
basic = (target_transaction * 8) // reduce_input_dtype.bits
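        # Worked example (illustrative only): assuming a 32-byte transaction size,
        # target_transaction is 64 bytes = 512 bits, so for float16 inputs
        # basic = 512 // 16 = 32 elements per reduction step.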
result = {}
for iter_info in node.raxis:
iter_name = iter_info.var.name
iter_dom = iter_info.dom.extent
if iter_dom % 16 > 0:
result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding
elif iter_dom % basic == 0:
result[iter_name] = basic
else:
return super()._assign_reduce_step(node)
return result
def _expand_reduce_axis(self, td: TileDict):
        # For a tensorcore program, if we get a small tile size, we should consider
        # expanding the reduce axis to improve compute efficiency.
def _check_small_tile(td: TileDict):
            minimal_threshold = 32
            for node in self.ordered_nodes:
                tile = td.get_tile(node)
                if any(t <= minimal_threshold for t in tile):
return True
return False
if _check_small_tile(td):
smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
# todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
for k in all_steps:
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
if any([v == [] for v in all_steps.values()]):
return rstep
def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map,
td.tensor_strides_map[node])
def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, input_buffer in enumerate(input_buffers):
score += coalesced_factor(shape[i], input_buffer.shape)
return score
def _enlarge(rstep_id):
candidates = []
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
if len(candidates) == 0:
return None
return max(candidates, key=lambda x: x[1])[0]
cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
for k in node.raxis
}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td)
td.rstep_map = old_rstep_map
if smem_usage > smem_limit:
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
return rstep
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
if self.block_reduction_depth is not None:
def _expand_with_tags(rstep):
new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()}
return new_rstep
rstep_map = td.rstep_map.copy()
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _expand_with_tags(rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map
return
def get_node_reduce_step_candidates(self, node):
if not node.get_tag("tensorcore_config"):
return super().get_node_reduce_step_candidates(node)
else:
            # must be a multiple of wmma_k
return {
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}
def check_tile_shape_isvalid(self, td: TileDict):
for node in self.ordered_nodes:
if node.get_tag("tensorcore_config"):
ax_m, ax_n = node.get_tag("tensorcore_config")
block_m, block_n = (
td.tile_map[node][ax_m],
td.tile_map[node][ax_n],
)
# check the tile size is valid
wmma_invalid = [
block_m < wmma_m or block_n < wmma_n
for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
]
if all(wmma_invalid):
return False
if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):
return False
return super().check_tile_shape_isvalid(td)
def _can_implement_layout(self, node: PrimFuncNode, td: TileDict):
# Not implemented yet
# This function is used to check whether we can implement swizzling
# layout under this tile config
return False
def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict):
if not node.get_tag("tensorcore_config"):
return super().compute_node_stride_map(node, td)
use_layout = self._can_implement_layout(node, td)
AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node),
td.get_rstep(node))
A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
tensor_strides = {}
output_strides = {
int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
}
tensor_strides = {}
# when connected to shared input, should use full stride without rstep
for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):
if use_layout:
continue
_ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name
# TODO(lei): should dig further for shared memory connection case.
return output_strides, tensor_strides
def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if not node.get_tag("tensorcore_config"):
return super()._assign_block_size(node, td, block_size)
ax_m, ax_n = node.get_tag("tensorcore_config")
if block_size % self.arch.warp_size != 0:
return None
tile, rsteps = td.get_tile(node), td.get_rstep(node)
warps = block_size // self.arch.warp_size
ndim = len(tile)
wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1]
wmma_tile = [1 for _ in range(ndim)]
wmma_tile[ax_m] = wmma[0]
wmma_tile[ax_n] = wmma[1]
space = [tile[i] // wmma_tile[i] for i in range(ndim)]
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow pad, otherwise, we can not get a valid tile shape
return None
factors = factorize(np.prod(space) // warps)
def _score(node, thread): # small is better
score = 0
block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
shape = node.propagate_inputs_on_reduction(block_tile)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, _ in enumerate(input_buffers):
score += np.prod(shape[i]) / self.arch.bandwidth[1]
return score
warp_tile = wmma_tile.copy()
for factor in reversed(factors):
score_map = {}
for i in range(ndim):
if tile[i] % (warp_tile[i] * factor) != 0:
continue
warp_tile[i] *= factor
score_map[i] = (_score(node, warp_tile), i)
warp_tile[i] //= factor
if len(score_map) == 0:
return None
dim_order = sorted(score_map.keys(), key=lambda x: score_map[x])
warp_tile[dim_order[0]] *= factor
codegen_dict = Hint()
codegen_dict.block = tile
codegen_dict.warp = warp_tile
codegen_dict.use_tc = True
codegen_dict.pipeline_stage = self.pipeline_stage
codegen_dict.block_reduction_depth = self.block_reduction_depth
codegen_dict.use_async = self.use_async_copy
codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis]
codegen_dict.cached_tensors = td.cached_tensors_map[node]
codegen_dict.rasterization_plan = self.plan_rasterization(td)
intrin_info = node.get_tag("intrin_info")
if intrin_info:
codegen_dict.intrin_info = IntrinInfo(**intrin_info)
if intrin_info["out_dtype"] in ["float32"]:
codegen_dict.shared_scope = "shared.dyn"
        # smem capacity
        # TODO: This is a dummy multiplier which avoids reusing some shared memory.
        # It should be removed in the future.
        if td.smem_cost > self.arch.smem_cap:
            # Shared memory for this tile exceeds the static capacity;
            # use dynamic shared memory instead.
codegen_dict.shared_scope = "shared.dyn"
codegen_dict.shared_scope = "shared.dyn"
codegen_dict.complete_config(node)
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.tensorcore_legalization()
return codegen_dict
def plan_rasterization(self, td: TileDict):
        conditions = []
        # rasterization is only supported for a single node for now
        conditions.append(len(self.ordered_nodes) > 1)
        # rasterization is only enabled on Ampere (sm_80) or newer
        conditions.append(self.arch.compute_capability < "80")
def _check_memory_size():
overall_gmem_size_in_bytes: int = 0
for node in self.ordered_nodes:
for buffer in node.input_buffers:
overall_gmem_size_in_bytes += (
int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8)
return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes
conditions.append(_check_memory_size())
if any(conditions):
return NoRasterization()
# otherwise, simply provide a block rasterization factor
raster_factor = int(self.arch.compute_max_core**0.5)
return Rasterization2DColumn(raster_factor)
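        # Worked example (illustrative only): on a device whose arch reports
        # compute_max_core = 108 (e.g. an A100, an assumption for this illustration),
        # the raster factor is int(108 ** 0.5) = 10 columns per panel.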
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rasteration Plan For L2 Cache Locality"""
from typing import List, Optional
class Rasterization:
panel_width_ = None
def __init__(self) -> None:
pass
def get_code(self) -> List[str]:
raise NotImplementedError()
@property
def panel_width(self):
assert self.panel_width_ is not None
return self.panel_width_
class NoRasterization(Rasterization):
def __init__(self) -> None:
super().__init__()
def __repr__(self) -> str:
return "<NoRasterization>"
def get_code(self) -> List[str]:
return []
class Rasterization2DRow(Rasterization):
"""
Rasterization by Row, each Row line width is panel_width
_________
_________|
|_________
__________|
"""
def __init__(self, panel_width=4) -> None:
super().__init__()
self.panel_width_ = panel_width
def __repr__(self) -> str:
return f"<Rasterization2DRow({self.panel_width_})>"
def get_code(self) -> List[str]:
raise NotImplementedError()
class Rasterization2DColumn(Rasterization):
"""
Rasterization by Column, each column line width is panel_width
_
| | | |
| | | |
|_| |_|
"""
def __init__(self, panel_width=4) -> None:
super().__init__()
self.panel_width_ = panel_width
def __repr__(self) -> str:
return f"<Rasterization2DColumn({self.panel_width_})>"
def get_device_function(self) -> str:
return """
__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {
const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y;
const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x);
const auto totalBlock = gridDim.x * gridDim.y;
const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x);
const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x;
const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd;
const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width;
const auto bz = blockIdx.z;
dim3 blockIdx(bx, by, bz);
return blockIdx;
}
"""
    def get_code(self, panel_width: Optional[int] = None) -> List[str]:
if panel_width is None:
panel_width = self.panel_width_
return [
self.get_device_function(),
"const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width),
]
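    # Illustrative usage (not part of the original module):
    #   plan = Rasterization2DColumn(panel_width=4)
    #   lines = plan.get_code()
    #   # lines[0] is the device helper above and
    #   # lines[1] == "const dim3 blockIdx = rasterization2DColumn(4);\n"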